Ejemplo n.º 1
0
class MotionData(object):
    """Fixed-length, newest-first buffer of values delivered by a ``TestClient``.

    ``data_array`` starts as ``num_length`` ``None`` slots. Each received value is
    placed at index 0 and the last entry is dropped, so the length never changes.
    Every read and write of the buffer happens under ``mutex`` (defined outside
    this class).
    """

    def __init__(self, num_length):
        self.test_client = TestClient()
        self.data_array = [None for _ in range(num_length)]
        self.num_length = num_length

    def receive_data(self, data):
        """Push ``data`` onto the front of the buffer and discard the oldest entry.

        Installed as the client's ``dataListener`` by :meth:`start`.
        """
        with mutex:
            buffer = self.data_array
            buffer.insert(0, data)
            del buffer[-1]

    def start(self):
        """Register :meth:`receive_data` as the listener, run the client, then log."""
        client = self.test_client
        client.dataListener = self.receive_data
        client.run()
        print('Start the interface thread')

    def stop(self):
        """No-op: nothing is torn down."""
        return None

    def get(self):
        """Return the most recently received value.

        Returns ``None`` while nothing has been received yet; raises IndexError
        when the buffer was created with ``num_length`` 0.
        """
        with mutex:
            latest = self.data_array[0]
        return latest
Ejemplo n.º 2
0
    def setUp(self):
        """Reset the database, create two users and one meal request.

        The first user posts a request; its ``Location`` header is stored in
        ``self.request_location``. ``self.client`` is then rebuilt with the second
        user's token, since that user makes proposals for the first user's request.
        """
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()

        # First user: issues the request.
        requester = User(username=self.default_username)
        requester.set_password_hash(self.default_password)
        db.session.add(requester)
        db.session.commit()
        self.client = TestClient(self.app, requester.generate_auth_token(), '')

        new_request = dict(
            meal_type='Vietnamese',
            meal_time='Dinner',
            location_string='San Francisco',
        )
        response, _ = self.client.post(API_VERSION + '/requests/',
                                       data=new_request)
        self.request_location = response.headers['Location']

        # Second user: will make a proposal for the first user's request.
        proposer = User(username='******')
        proposer.set_password_hash('123456')
        db.session.add(proposer)
        db.session.commit()
        self.client = TestClient(self.app, proposer.generate_auth_token(), '')
Ejemplo n.º 3
0
 def set_environment(self, env):
     """
     Switch the Square app environment used by the tests.

     Writes ``env`` to TAP_SQUARE_ENVIRONMENT, builds a new TestClient for it and
     records it on ``self.SQUARE_ENVIRONMENT``.
     """
     os.environ['TAP_SQUARE_ENVIRONMENT'] = env
     fresh_client = TestClient(env=env)
     self.client, self.SQUARE_ENVIRONMENT = fresh_client, env
Ejemplo n.º 4
0
 def setUp(self):
     """Push an app context, rebuild the schema and authenticate a default user."""
     self.app = app
     self.ctx = self.app.app_context()
     self.ctx.push()
     db.drop_all()
     db.create_all()

     user = User(username=self.default_username)
     user.set_password(self.default_password)
     session = db.session
     session.add(user)
     session.commit()

     token = user.generate_auth_token()
     self.client = TestClient(self.app, token, '')
Ejemplo n.º 5
0
    def setUp(self):
        """Create the 'testing' app, push its context and log in a default user."""
        self.app = create_app('testing')
        self.app_context = self.app.app_context()
        self.app_context.push()
        db.create_all()

        default_user = User(
            username=self.default_username,
            password=self.default_password,
        )
        session = db.session
        session.add(default_user)
        session.commit()

        auth_token = default_user.generate_auth_token()
        self.client = TestClient(self.app, auth_token, '')
Ejemplo n.º 6
0
 def tearDownClass(cls):
     """Trim streams whose live record count exceeds half of their cap.

     Switches back to the sandbox. For each stream in ``cleanup``, it lists the
     records since STATIC_START_DATE and ignores those flagged ``is_deleted``.
     If more than ``limit / 2`` remain, it deletes the first surplus ids returned.
     """
     cls.set_environment(cls, cls.SANDBOX)
     cleanup = {'categories': 10000}
     client = TestClient(env=os.environ['TAP_SQUARE_ENVIRONMENT'])
     for stream, limit in cleanup.items():
         print("Checking if cleanup is required.")
         records = client.get_all(stream, start_date=cls.STATIC_START_DATE)
         live_ids = [record.get('id') for record in records
                     if not record.get('is_deleted')]
         half = limit / 2
         if len(live_ids) <= half:
             continue
         excess = int(len(live_ids) - half)
         print("Cleaning up {} excess records".format(excess))
         client.delete_catalog(live_ids[:excess])
Ejemplo n.º 7
0
 def __init__(self, *args, **kwargs):
     """Initialise the base class, then build a TestClient from the tap config.

     ``start_date`` comes from :meth:`get_properties`. The client receives the
     properties merged with :meth:`get_credentials`; credentials win on key clashes.
     """
     super().__init__(*args, **kwargs)
     self.start_date = self.get_properties().get("start_date")
     client_config = dict(self.get_properties())
     client_config.update(self.get_credentials())
     self.client = TestClient(client_config)
Ejemplo n.º 8
0
 def get_credentials(self):
     """Return the OAuth fields from ``TestClient.get_token_information()``.

     Raises KeyError if any expected field is missing from the token information.
     """
     token = TestClient.get_token_information()
     wanted = ('refresh_token', 'client_id', 'client_secret', 'access_token')
     return {field: token[field] for field in wanted}
Ejemplo n.º 9
0
 def setUp(self):
     """Start from an empty schema inside a pushed app context with a logged-in user."""
     self.app = app
     context = self.app.app_context()
     self.ctx = context
     context.push()
     db.drop_all()
     db.create_all()

     account = User(username=self.default_username)
     account.set_password(self.default_password)
     db.session.add(account)
     db.session.commit()
     self.client = TestClient(self.app, account.generate_auth_token(), '')
Ejemplo n.º 10
0
class AuthTest(test.TestCase):
    """OAuth checks for ``BasketAuthentication`` driven through ``TestClient``."""

    def setUp(self):
        self.auth_provider = BasketAuthentication()
        self.c = TestClient()

    def _authenticated(self, request):
        """Ask the provider under test whether ``request`` is authenticated."""
        return self.auth_provider.is_authenticated(request)

    def test_valid_auth(self):
        request = self.c.subscribe_request(email='*****@*****.**')
        eq_(self._authenticated(request), True)

    def test_invalid_auth(self):
        # Swap in a consumer whose credentials the provider should reject.
        self.c.consumer = oauth.Consumer('fail', 'fail')
        request = self.c.subscribe_request()
        eq_(self._authenticated(request), False)

    def test_HTTP_AUTHORIZATION_header(self):
        request = self.c.subscribe_request(email='*****@*****.**')
        # Re-key the header as 'HTTP_AUTHORIZATION', the variant under test.
        request.META['HTTP_AUTHORIZATION'] = request.META.pop('Authorization')
        eq_(self._authenticated(request), True)
Ejemplo n.º 11
0
    def setUp(self):
        self.app = create_app('testing')

        # add an additional route used only in tests
        @self.app.route('/foo')
        @async
        def foo():
            1 / 0

        @self.app.route('/test-long')
        @async
        def long_task():
            sleep(10)
            return jsonify({}), 200

        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()  # just in case
        db.create_all()
        u = User(username=self.default_username,
                 password=self.default_password)

        self.client = TestClient(self.app, u.generate_auth_token(),'')
Ejemplo n.º 12
0
def run_deepweb_test():
    """Archive every page listed in the deep-web fixture files and time each request.

    Each fixture file is a JSON object mapping a website name to its URL. Every
    URL is submitted via ``TestClient.get_archive`` (service at
    http://0.0.0.0:8080) with ``forceTor`` set. The elapsed time of each request
    is printed.
    """
    t = TestClient(service_host="http://0.0.0.0:8080", auth=('admin', 'admin'))

    test_files_dict = {
        "Hidden Service lists and search engines": "test_link_sites.json",
        "Marketplace Financial": "test_finance.json",
        "Marketplace Commercial Services": "test_commercial_services.json",
        "Blogs and radios": "test_blogs.json",
        "Politics": "test_politics.json"
    }
    separator = (
        "======================================================================"
    )

    for tests_name, tests_file in test_files_dict.items():
        print(separator)
        print("Website group: %s from file %s" % (tests_name, tests_file))
        print(separator)
        # Load the fixture up front so the file is not held open during the
        # (potentially slow) archiving requests below.
        with open(tests_file, encoding='utf-8') as f:
            tests_json = json.load(f)

        # Get archives for each individual page
        for web_name, url in tests_json.items():
            print("Website: %s, URL: %s" % (web_name, url))
            data = {
                'urls': [url],
                'name': "DeepWebTest-{}".format(t.test_number),
                'headers': {},
                'forceTor': True,
            }

            # perf_counter is monotonic, so the duration is not skewed by
            # system clock adjustments the way time.time() differences can be.
            timing_start = time.perf_counter()
            t.get_archive(data=data)
            elapsed = time.perf_counter() - timing_start
            print("  archiving request time: %s" % elapsed)
Ejemplo n.º 13
0
class TestAPI(unittest.TestCase):
    """End-to-end checks of the ``/customers/`` API through a token-authenticated client."""
    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Push an app context, rebuild the schema and authenticate a default user."""
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()
        user = User(username=self.default_username)
        user.set_password(self.default_password)
        db.session.add(user)
        db.session.commit()
        self.client = TestClient(self.app, user.generate_auth_token(), '')

    def tearDown(self):
        """Release the session, drop all tables and pop the app context."""
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_customers(self):
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        # get empty list of customers
        rv, body = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['customers'], [])

        # add the customer
        rv, body = self.client.post('/customers/', data={'name': 'john'})
        self.assertEqual(rv.status_code, 201)
        location = rv.headers['Location']
        rv, body = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['name'], 'john')
        rv, body = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['customers'], [location])

        # edit the customer
        rv, body = self.client.put(location, data={'name': 'John Smith'})
        self.assertEqual(rv.status_code, 200)
        rv, body = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['name'], 'John Smith')
Ejemplo n.º 14
0
class TestAPI(unittest.TestCase):
    """End-to-end checks of the ``/customers/`` API through a token-authenticated client."""
    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Push an app context, rebuild the schema and authenticate a default user."""
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()
        user = User(username=self.default_username)
        user.set_password(self.default_password)
        db.session.add(user)
        db.session.commit()
        self.client = TestClient(self.app, user.generate_auth_token(), '')

    def tearDown(self):
        """Release the session, drop all tables and pop the app context."""
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_customers(self):
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        # get list of customers
        rv, body = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['customers'], [])

        # add a customer
        rv, body = self.client.post('/customers/', data={'name': 'john'})
        self.assertEqual(rv.status_code, 201)
        location = rv.headers['Location']
        rv, body = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['name'], 'john')
        rv, body = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['customers'], [location])

        # edit the customer
        rv, body = self.client.put(location, data={'name': 'John Smith'})
        self.assertEqual(rv.status_code, 200)
        rv, body = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(body['name'], 'John Smith')
Ejemplo n.º 15
0
class CeleryTestCase(unittest.TestCase):
    default_username = '******'
    default_password = '******'

    def setUp(self):
        self.app = create_app('testing')

        # add an additional route used only in tests
        @self.app.route('/foo')
        @async
        def foo():
            1 / 0

        @self.app.route('/test-long')
        @async
        def long_task():
            sleep(10)
            return jsonify({}), 200

        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()  # just in case
        db.create_all()
        u = User(username=self.default_username,
                 password=self.default_password)

        self.client = TestClient(self.app, u.generate_auth_token(),'')

    def tearDown(self):
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_celery(self):
        response, json_response = self.client.get('/test-long')
        self.assertEquals(response.status_code,202)
Ejemplo n.º 16
0
 def setUp(self):
     """Create the authentication provider under test and a fresh TestClient."""
     provider = BasketAuthentication()
     client = TestClient()
     self.auth_provider, self.c = provider, client
Ejemplo n.º 17
0
"""
This is used for testing basic functionality of the test client.
To run it, change the desired flags below and use the following command from the tap-tester repo:
  'python ../tap-square/tests/client_tester.py'
"""
from datetime import datetime

from test_client import TestClient

##########################################################################
# Testing the TestClient
##########################################################################
if __name__ == "__main__":
    client = TestClient(env='sandbox')
    # START_DATE = '2020-06-24T00:00:00Z'
    START_DATE = datetime.strftime(datetime.utcnow(), '%Y-%m-%dT00:00:00Z')

    # CHANGE FLAGS HERE TO TEST SPECIFIC FUNCTION TYPES
    test_creates = True
    test_updates = False  # To test updates, must also test creates
    test_gets = False
    test_deletes = True  # To test deletes, must also test creates

    # CHANGE FLAG TO PRINT ALL OBJECTS THAT FUNCTIONS INTERACT WITH
    print_objects = True

    objects_to_test = [  # CHANGE TO TEST DESIRED STREAMS
        # 'modifier_lists', # GET - DONE | CREATE -  | UPDATE -  | DELETE -
        # 'inventories', # GET - DONE | CREATE - DONE | UPDATE - DONE | DELETE -
        # 'items',  # GET - DONE | CREATE - DONE | UPDATE - DONE | DELETE -
        'categories',  # GET - DONE | CREATE - DONE | UPDATE - DONE | DELETE - DONE
Ejemplo n.º 18
0
class SubscriptionTest(test.TestCase):
    """Model validation and subscribe-endpoint status codes for subscriptions."""

    def setUp(self):
        self.c = TestClient()
        # Bound method: call self.count() to get the current Subscription count.
        self.count = Subscription.objects.count

    def tearDown(self):
        self.c.tearDown()

    def valid_subscriber(self):
        """Return an unsaved Subscriber with a well-formed email."""
        return Subscriber(email='*****@*****.**')

    def valid_subscription(self):
        """Return an unsaved Subscription for campaign 'test' tied to a saved Subscriber."""
        subscriber = self.valid_subscriber()
        subscriber.save()
        s = Subscription(campaign='test')
        s.subscriber = subscriber
        return s

    def test_subscriber_validation(self):
        # test valid subscriber
        a = self.valid_subscriber()
        a.save()

        # fail on blank email
        a.email = ''
        self.assertRaises(ValidationError, a.full_clean)

        # fail on bad email format
        a.email = 'foo'
        self.assertRaises(ValidationError, a.full_clean)

    def test_subscription_validation(self):
        # test valid subscription
        a = self.valid_subscription()
        a.save()

        # fail on blank campaign
        b = self.valid_subscription()
        b.campaign = ''
        self.assertRaises(ValidationError, b.full_clean)

        # fail on bad locale
        c = self.valid_subscription()
        c.locale = 'foo'
        self.assertRaises(ValidationError, c.full_clean)

    def test_locale_fallback(self):
        """A blank locale will fall back to en-US"""

        a = self.valid_subscription()
        a.locale = ''
        a.full_clean()
        eq_(a.locale, 'en-US')

    def test_valid_locale(self):
        """A valid locale, other than en-US, works"""

        a = self.valid_subscription()
        a.locale = 'es-ES'
        a.full_clean()
        a.save()
        eq_(Subscription.objects.filter(locale='es-ES').count(), 1)

    def test_active_default(self):
        """A new record is active by default"""

        a = self.valid_subscription()
        a.save()
        eq_(a.active, True)

    def test_status_codes(self):
        # validation errors return 400
        resp = self.c.subscribe(email='')
        eq_(resp.status_code, 400, resp.content)
        eq_(self.count(), 0)

        # success returns 201
        resp = self.c.subscribe(email='*****@*****.**', campaign='foo')
        eq_(resp.status_code, 201, resp.content)
        eq_(self.count(), 1)
        eq_(Subscription.objects.filter(subscriber__email='*****@*****.**').count(), 1)

        # duplicate returns 409
        resp = self.c.subscribe(email='*****@*****.**', campaign='foo')
        eq_(resp.status_code, 409, resp.content)
        eq_(self.count(), 1)
        eq_(Subscription.objects.filter(subscriber__email='*****@*****.**').count(), 1)

    def test_read(self):
        resp = self.c.read()
        eq_(resp.status_code, 501)
Ejemplo n.º 19
0
    class TestSquareBase(ABC, TestCase):
        # Environment names; set_environment() writes one of these to
        # TAP_SQUARE_ENVIRONMENT.
        PRODUCTION = "production"
        SANDBOX = "sandbox"
        # Default environment; set_environment() overrides it on the instance.
        SQUARE_ENVIRONMENT = SANDBOX
        # Keys used inside the per-stream dicts built by expected_metadata().
        REPLICATION_KEYS = "valid-replication-keys"
        PRIMARY_KEYS = "table-key-properties"
        REPLICATION_METHOD = "forced-replication-method"
        START_DATE_KEY = 'start-date-key'
        # Replication-method values stored under REPLICATION_METHOD.
        INCREMENTAL = "INCREMENTAL"
        FULL = "FULL_TABLE"
        # strftime format of the default 'start_date' built in get_properties().
        START_DATE_FORMAT = "%Y-%m-%dT00:00:00Z"
        # Fixed start date; not referenced by the methods shown here --
        # presumably used by tests relying on a static data set (confirm).
        STATIC_START_DATE = "2020-07-13T00:00:00Z"
        # Returned as 'start_date' by get_properties(original=False); presumably
        # assigned by subclasses/tests before that call (confirm).
        START_DATE = ""
        # Streams removed from expected_metadata() when running in the sandbox.
        # NOTE(review): production_streams() also lists 'employees'; confirm the
        # difference is intentional.
        PRODUCTION_ONLY_STREAMS = {'roles', 'bank_accounts', 'settlements'}

        # Per-stream numeric limits; presumably the page/batch size the Square
        # API uses for each stream -- confirm against the tap implementation.
        DEFAULT_BATCH_LIMIT = 1000
        API_LIMIT = {
            'items': DEFAULT_BATCH_LIMIT,
            'inventories': 100,
            'categories': DEFAULT_BATCH_LIMIT,
            'discounts': DEFAULT_BATCH_LIMIT,
            'taxes': DEFAULT_BATCH_LIMIT,
            'cash_drawer_shifts': DEFAULT_BATCH_LIMIT,
            'employees': 50,
            'locations':
            None,  # Api does not accept a cursor and documents no limit, see https://developer.squareup.com/reference/square/locations/list-locations
            'roles': 100,
            'refunds': 100,
            'payments': 100,
            'customers': 100,
            'modifier_lists': DEFAULT_BATCH_LIMIT,
            'orders': 500,
            'shifts': 200,
            'settlements': 200,
        }

        def setUp(self):
            """Fail fast without sandbox credentials, then switch to the sandbox.

            Raises ``Exception`` naming every required TAP_SQUARE_* variable that
            is unset.
            """
            required = (
                "TAP_SQUARE_REFRESH_TOKEN",
                "TAP_SQUARE_APPLICATION_ID",
                "TAP_SQUARE_APPLICATION_SECRET",
            )
            missing_envs = [name for name in required if os.getenv(name) is None]
            if missing_envs:
                raise Exception(
                    "Missing environment variables: {}".format(missing_envs))

            self.maxDiff = None  # show full diffs in assertion failures
            # Every test starts against the sandbox environment.
            self.set_environment(self.SANDBOX)

        @staticmethod
        def get_type():
            """Return the platform type identifier for this tap."""
            platform_type = "platform.square"
            return platform_type

        @staticmethod
        def tap_name():
            """Return the name of the tap under test."""
            name = "tap-square"
            return name

        def set_environment(self, env):
            """
            Switch the Square app environment used by the tests.

            Writes ``env`` to TAP_SQUARE_ENVIRONMENT, then rebuilds ``self.client``
            for it and records it on ``self.SQUARE_ENVIRONMENT``.
            """
            os.environ['TAP_SQUARE_ENVIRONMENT'] = env
            new_client = TestClient(env=env)
            self.client = new_client
            self.SQUARE_ENVIRONMENT = env

        @staticmethod
        def get_environment():
            """Return TAP_SQUARE_ENVIRONMENT (KeyError if it is unset)."""
            environment = os.environ['TAP_SQUARE_ENVIRONMENT']
            return environment

        def get_properties(self, original=True):
            """Build the tap configuration.

            ``start_date`` is three days before now (UTC, START_DATE_FORMAT) by
            default; with ``original=False`` it is ``self.START_DATE`` instead.
            ``sandbox`` is 'true' when the current environment is the sandbox and
            'false' otherwise.
            """
            if original:
                start_date = dt.strftime(dt.utcnow() - timedelta(days=3),
                                         self.START_DATE_FORMAT)
            else:
                start_date = self.START_DATE
            in_sandbox = self.get_environment() == self.SANDBOX
            return {
                'start_date': start_date,
                'sandbox': 'true' if in_sandbox else 'false',
            }

        @staticmethod
        def get_credentials():
            """Return OAuth credentials for the current TAP_SQUARE_ENVIRONMENT.

            Sandbox reads TAP_SQUARE_REFRESH_TOKEN / _APPLICATION_ID /
            _APPLICATION_SECRET; production reads the TAP_SQUARE_PROD_* variants.
            Unset variables yield None. Any other environment raises ``Exception``.
            """
            environment = os.getenv('TAP_SQUARE_ENVIRONMENT')
            if environment not in ['sandbox', 'production']:
                raise Exception(
                    "Square Environment: {} is not supported.".format(
                        environment))

            use_sandbox = environment == 'sandbox'
            refresh_var = ('TAP_SQUARE_REFRESH_TOKEN' if use_sandbox
                           else 'TAP_SQUARE_PROD_REFRESH_TOKEN')
            client_id_var = ('TAP_SQUARE_APPLICATION_ID' if use_sandbox
                             else 'TAP_SQUARE_PROD_APPLICATION_ID')
            secret_var = ('TAP_SQUARE_APPLICATION_SECRET' if use_sandbox
                          else 'TAP_SQUARE_PROD_APPLICATION_SECRET')
            return {
                'refresh_token': os.getenv(refresh_var),
                'client_id': os.getenv(client_id_var),
                'client_secret': os.getenv(secret_var),
            }

        def expected_check_streams(self):
            """Return a new set of every stream name in expected_metadata()."""
            return {stream for stream in self.expected_metadata()}

        def expected_metadata(self):
            """The expected streams and metadata about the streams.

            Each stream maps to a dict holding its primary keys and replication
            method, plus either replication keys (incremental streams) or an
            optional start-date key (full-table streams). In the sandbox,
            PRODUCTION_ONLY_STREAMS are left out.
            """

            def full_table(start_date_key=None, primary_keys=None):
                meta = {
                    self.PRIMARY_KEYS: {'id'} if primary_keys is None else primary_keys,
                    self.REPLICATION_METHOD: self.FULL,
                }
                if start_date_key is not None:
                    meta[self.START_DATE_KEY] = start_date_key
                return meta

            def incremental():
                return {
                    self.PRIMARY_KEYS: {'id'},
                    self.REPLICATION_METHOD: self.INCREMENTAL,
                    self.REPLICATION_KEYS: {'updated_at'},
                }

            all_streams = {
                "bank_accounts": full_table(),
                "cash_drawer_shifts": full_table(),
                "customers": incremental(),
                "categories": incremental(),
                "discounts": incremental(),
                "employees": full_table(start_date_key='updated_at'),
                "inventories": full_table(start_date_key='calculated_at',
                                          primary_keys=set()),
                "items": incremental(),
                "locations": full_table(),
                "modifier_lists": incremental(),
                "orders": incremental(),
                "payments": full_table(),
                "refunds": full_table(),
                "roles": full_table(start_date_key='updated_at'),
                "settlements": full_table(),
                "shifts": incremental(),
                "taxes": incremental(),
            }

            if self.get_environment() != self.SANDBOX:
                return all_streams
            return {
                name: meta
                for name, meta in all_streams.items()
                if name not in self.PRODUCTION_ONLY_STREAMS
            }

        def expected_replication_method(self):
            """Map each expected stream to its replication method (None if absent)."""
            methods = {}
            for table, properties in self.expected_metadata().items():
                methods[table] = properties.get(self.REPLICATION_METHOD)
            return methods

        @staticmethod
        def production_streams():
            """Streams that only hold data in the production app; tested separately."""
            names = ('employees', 'roles', 'bank_accounts', 'settlements')
            return set(names)

        def sandbox_streams(self):
            """Expected streams minus the production-only ones (the default test set)."""
            return self.expected_streams() - self.production_streams()

        @staticmethod
        def static_data_streams():
            """
            Streams that must use a static data set and belong only in static tests.
            """
            static = {'locations'}  # Limit 300 objects, DELETES not supported
            return static

        @staticmethod
        def untestable_streams():
            """Streams that cannot currently be tested."""
            untestable = set()
            untestable.add('bank_accounts')  # No endpoints for CREATE or UPDATE
            untestable.add('cash_drawer_shifts')  # Require cash transactions (not supported by API)
            untestable.add('settlements')  # Dependent on bank_account related transactions, no endpoints for CREATE or UPDATE
            return untestable

        def dynamic_data_streams(self):
            """Expected streams excluding those that rely on static data."""
            return self.expected_streams() - self.static_data_streams()

        def expected_streams(self):
            """Return the set of expected stream names."""
            return set(self.expected_metadata())

        def expected_incremental_streams(self):
            """Return the streams whose replication method is INCREMENTAL."""
            methods = self.expected_replication_method()
            return {name for name, method in methods.items()
                    if method == self.INCREMENTAL}

        def expected_full_table_streams(self):
            """Return the streams whose replication method is FULL_TABLE."""
            methods = self.expected_replication_method()
            return {name for name, method in methods.items()
                    if method == self.FULL}

        @abstractmethod
        def testable_streams_dynamic(self):
            """Subclasses return the set of streams testable with dynamic data."""

        @abstractmethod
        def testable_streams_static(self):
            """Subclasses return the set of streams testable with static data."""

        def testable_streams(self,
                             environment='sandbox',
                             data_type=DataType.DYNAMIC):
            """Return the streams testable for ``environment`` and ``data_type``.

            Intersects the subclass's dynamic (or otherwise static) testable
            streams with the sandbox or production stream set. Raises
            NotImplementedError for an unknown environment.
            """
            allowed_environments = {self.SANDBOX, self.PRODUCTION}
            if environment not in allowed_environments:
                raise NotImplementedError(
                    "Environment can only be one of {}, but {} provided".
                    format(allowed_environments, environment))

            if data_type == DataType.DYNAMIC:
                candidates = self.testable_streams_dynamic()
            else:
                candidates = self.testable_streams_static()

            if environment == self.SANDBOX:
                return candidates.intersection(self.sandbox_streams())
            return candidates.intersection(self.production_streams())

        def expected_primary_keys(self):
            """
            Map each expected stream name to its set of primary-key fields
            (an empty set when none are declared).
            """
            primary_keys = {}
            for table, properties in self.expected_metadata().items():
                primary_keys[table] = properties.get(self.PRIMARY_KEYS, set())
            return primary_keys

        @staticmethod
        def makeshift_primary_keys():
            """
            Map stream name to the fields used as a stand-in primary key for
            streams that declare none ('inventories').
            """
            inventory_key = {'catalog_object_id', 'location_id', 'state'}
            return {'inventories': inventory_key}

        def expected_replication_keys(self):
            """Map each incremental stream to its replication-key set (empty if none)."""
            incremental_streams = self.expected_incremental_streams()
            replication_keys = {}
            for table, properties in self.expected_metadata().items():
                if table not in incremental_streams:
                    continue
                replication_keys[table] = properties.get(self.REPLICATION_KEYS, set())
            return replication_keys

        def expected_stream_to_start_date_key(self):
            """Map each stream that declares a truthy start-date key to that key."""
            start_date_keys = {}
            for table, properties in self.expected_metadata().items():
                key = properties.get(self.START_DATE_KEY)
                if key:
                    start_date_keys[table] = key
            return start_date_keys

        def expected_automatic_fields(self):
            """Map each stream to the union of its primary keys and replication keys."""
            return {
                stream: meta.get(self.PRIMARY_KEYS, set())
                | meta.get(self.REPLICATION_KEYS, set())
                for stream, meta in self.expected_metadata().items()
            }

        def select_all_streams_and_fields(self,
                                          conn_id,
                                          catalogs,
                                          select_all_fields: bool = True,
                                          exclude_streams=None):
            """Select all streams and all fields within streams.

            Catalogs whose ``stream_name`` is in ``exclude_streams`` are skipped.
            When ``select_all_fields`` is False, every property except the
            stream's automatic fields (see :meth:`expected_automatic_fields`) is
            passed as non-selected.

            The annotated schema fetched from menagerie is not modified; it is
            handed on intact to ``select_catalog_and_fields_via_metadata``.
            """
            # Computed once rather than per catalog; only needed when deselecting.
            automatic_fields = {} if select_all_fields else self.expected_automatic_fields()

            for catalog in catalogs:
                if exclude_streams and catalog.get(
                        'stream_name') in exclude_streams:
                    continue
                schema = menagerie.get_annotated_schema(
                    conn_id, catalog['stream_id'])
                non_selected_properties = []
                if not select_all_fields:
                    properties = schema.get(
                        'annotated-schema', {}).get('properties', {})
                    automatic = automatic_fields.get(catalog['stream_name'], [])
                    # Build a separate list instead of deleting from `properties`:
                    # that dict belongs to `schema`, which is passed on below.
                    non_selected_properties = [
                        prop for prop in properties if prop not in automatic
                    ]

                connections.select_catalog_and_fields_via_metadata(
                    conn_id,
                    catalog,
                    schema,
                    additional_md=[],
                    non_selected_fields=non_selected_properties)

        @staticmethod
        def get_selected_fields_from_metadata(metadata):
            """Return the names of fields that are automatic or explicitly selected.

            Only field-level entries (breadcrumb longer than one element) are
            considered, and ``breadcrumb[1]`` is collected. Breadcrumb depth is
            checked first. A missing 'inclusion' or 'selected' flag counts as
            "not automatic"/"not selected", so entries without those keys no
            longer raise KeyError.
            """
            selected_fields = set()
            for field in metadata:
                if len(field['breadcrumb']) <= 1:
                    continue  # stream-level entry, not a field
                field_md = field['metadata']
                if (field_md.get('inclusion') == 'automatic'
                        or field_md.get('selected') is True):
                    selected_fields.add(field['breadcrumb'][1])
            return selected_fields

        @staticmethod
        def _get_abs_path(path):
            """Resolve ``path`` against the directory containing this file."""
            here = os.path.dirname(os.path.realpath(__file__))
            return os.path.join(here, path)

        def _load_schemas(self, stream):
            """Load the tap's JSON schema for ``stream``; returns ``{stream: schema}``.

            The schemas directory is this file's ``schemas`` directory with the
            last ``tests`` in its path swapped for ``tap_square``; for example,
            ``<repo>/tests/schemas`` becomes ``<repo>/tap_square/schemas``.
            Only that last occurrence is replaced, so a ``tests`` elsewhere in the
            absolute path (e.g. a parent directory) is left intact.
            """
            schemas_dir = self._get_abs_path("schemas")
            head, sep, tail = schemas_dir.rpartition('tests')
            if sep:
                schemas_dir = head + 'tap_square' + tail
            final_path = os.path.join(schemas_dir, stream + ".json")

            with open(final_path, encoding='utf-8') as schema_file:
                return {stream: json.load(schema_file)}

        @staticmethod
        def parse_date(date_value):
            """
            Parse a UTC timestamp string into a naive datetime object.

            The formats ``%Y-%m-%dT%H:%M:%S.%fZ`` (fractional seconds) and
            ``%Y-%m-%dT%H:%M:%SZ`` are tried in that order; the first match wins.

            Raises:
                NotImplementedError: if ``date_value`` matches neither format.
            """
            for date_format in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
                try:
                    return dt.strptime(date_value, date_format)
                except ValueError:
                    continue
            raise NotImplementedError(
                "We are not accounting for dates of this format: {}".format(
                    date_value))

        @staticmethod
        def date_check_and_parse(date_value):
            """
            Pass in any value and return that value. If the value is a
            string-formatted datetime, parse it and return a naive datetime.

            Recognized formats are ``%Y-%m-%dT%H:%M:%S.%fZ`` and
            ``%Y-%m-%dT%H:%M:%SZ``. Non-string values (None, numbers, dicts, ...)
            are returned unchanged, so ``strptime`` never sees them and raises
            TypeError.
            """
            if not isinstance(date_value, str):
                return date_value
            for date_format in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
                try:
                    return dt.strptime(date_value, date_format)
                except ValueError:
                    continue
            return date_value

        def expected_schema_keys(self, stream):
            """
            Return the top-level property names declared in ``stream``'s schema.

            Raises:
                AssertionError: if the schema has no (or an empty) ``properties``
                    mapping; the message names the offending stream.
            """
            props = self._load_schemas(stream).get(stream).get('properties')
            assert props, "{} schema not configured properly".format(stream)

            return props.keys()

        def modify_expected_records(self, records):
            """Normalize every record in ``records`` in place via ``modify_expected_record``."""
            for expected in records:
                self.modify_expected_record(expected)

        def modify_expected_record(self, expected_record):
            """
            Align expected data, in place, with how the tap _should_ emit it.

            Nested dicts are descended into and each element of a list value is
            visited. Every other value is passed to ``align_date_type`` and then
            ``align_number_type``. Inputs that are not dicts are ignored.
            """
            if not isinstance(expected_record, dict):
                return
            for field_name, field_value in expected_record.items():
                if isinstance(field_value, dict):
                    self.modify_expected_record(field_value)
                    continue
                if isinstance(field_value, list):
                    for element in field_value:
                        self.modify_expected_record(element)
                    continue
                self.align_date_type(expected_record, field_name, field_value)
                self.align_number_type(expected_record, field_name, field_value)

        def align_date_type(self, record, key, value):
            """
            Rewrite ``record[key]`` in ``%Y-%m-%dT%H:%M:%S.%fZ`` form when
            ``value`` is a date string.

            datetime values must conform to ISO-8601 or they will be rejected by
            the gate. The record is only changed when ``value`` is a str that
            ``date_check_and_parse`` turns into a datetime; otherwise it is left
            untouched.
            """
            if not isinstance(value, str):
                return
            # Parse once and reuse the result for both the check and the format.
            parsed = self.date_check_and_parse(value)
            if isinstance(parsed, dt):
                record[key] = dt.strftime(parsed, "%Y-%m-%dT%H:%M:%S.%fZ")

        @staticmethod
        def align_number_type(record, key, value):
            """
            Stringify float ``latitude``/``longitude`` values in ``record``.

            When ``key`` is ``'latitude'`` or ``'longitude'`` and ``value`` is a
            float, ``record[key]`` is replaced with ``str(value)``; anything else
            is left untouched. NOTE(review): presumably this matches the string
            form the tap emits for coordinates -- confirm against the schema.
            """
            if isinstance(value, float) and key in ('latitude', 'longitude'):
                record[key] = str(value)

        def create_test_data(self,
                             testable_streams,
                             start_date,
                             start_date_2=None,
                             min_required_num_records_per_stream=None,
                             force_create_records=False):
            """
            Ensure each testable stream has enough test data and return the
            expected records per stream.

            For each stream in ``testable_streams``, existing records are fetched
            with ``self.client.get_all(stream, start_date)``. New records are
            created with ``self.client.create`` when any of these hold:
              * no fetched record has a start-date-key value later than
                ``start_date_2`` (records missing the key do not count);
              * the stream has at most
                ``min_required_num_records_per_stream[stream]`` records;
              * ``force_create_records`` is true.
            The 'locations' stream is never created (see the WORKAROUND below).

            Args:
                testable_streams: iterable of stream names to fetch/create data for.
                start_date: date string passed to ``get_all`` and ``create``.
                start_date_2: date string that record start-date-key values are
                    compared against; defaults to ``start_date`` when falsy.
                min_required_num_records_per_stream: dict of stream -> minimum
                    record count; defaults to 1 for every testable stream.
                force_create_records: create records even when enough exist.

            Returns:
                dict mapping each stream in ``self.expected_streams()`` (plus every
                testable stream) to its list of expected records. Non-testable
                streams map to ``[]``. ``modify_expected_records`` has already
                been applied to every testable stream's list in place.
            """
            if min_required_num_records_per_stream is None:
                min_required_num_records_per_stream = {
                    stream: 1
                    for stream in testable_streams
                }

            if not start_date_2:
                start_date_2 = start_date

            # Processing order. Each _shift_to_start_of_list call moves its key to
            # the very front, so (when present) 'modifier_lists' ends up first and
            # 'employees' second. Each _shift_to_end_of_list call moves its key to
            # the very back, so the tail is 'payments', 'orders', 'items' in that
            # order.
            create_test_data_streams = list(testable_streams)
            create_test_data_streams = self._shift_to_start_of_list(
                'employees', create_test_data_streams)
            create_test_data_streams = self._shift_to_start_of_list(
                'modifier_lists', create_test_data_streams)
            # Creating a refund results in a new payment, so 'payments' is
            # processed after 'refunds' to keep its record count consistent.
            create_test_data_streams = self._shift_to_end_of_list(
                'payments', create_test_data_streams)
            # Creating a payment results in a new order; processing 'orders'
            # after 'payments' keeps the number of orders consistent.
            create_test_data_streams = self._shift_to_end_of_list(
                'orders', create_test_data_streams)
            # Creating an inventory results in a new item; processing 'items'
            # last keeps the number of items consistent.
            create_test_data_streams = self._shift_to_end_of_list(
                'items', create_test_data_streams)

            # Every expected stream starts with no expected records; testable
            # streams are filled in below.
            stream_to_expected_records = {
                stream: []
                for stream in self.expected_streams()
            }

            for stream in create_test_data_streams:
                stream_to_expected_records[stream] = self.client.get_all(
                    stream, start_date)

                start_date_key = self.get_start_date_key(stream)
                # Create data when no record is newer than start_date_2, when the
                # stream has at most the required minimum, or when forced.
                if (not any([
                        stream_obj.get(start_date_key)
                        and self.parse_date(stream_obj.get(start_date_key)) >
                        self.parse_date(start_date_2)
                        for stream_obj in stream_to_expected_records[stream]
                ]) or len(stream_to_expected_records[stream]) <=
                        min_required_num_records_per_stream[stream]
                        or force_create_records):

                    # Enough to exceed the minimum count, but always at least one.
                    num_records = max(
                        1, min_required_num_records_per_stream[stream] + 1 -
                        len(stream_to_expected_records[stream]))

                    LOGGER.info(
                        "Data missing for stream %s, will create %s record(s)",
                        stream, num_records)
                    # WORKAROUND to prevent more locations being created. We currently are at the max(300)
                    # NOTE: the "will create" log line above is still emitted for
                    # 'locations' even though nothing is created for it.
                    if stream != 'locations':

                        created_records = self.client.create(
                            stream,
                            start_date=start_date,
                            num_records=num_records)

                        # create() may return a single record (dict) or a list
                        # of records; only the list form is count-checked.
                        if isinstance(created_records, dict):
                            stream_to_expected_records[stream].append(
                                created_records)
                        elif isinstance(created_records, list):
                            stream_to_expected_records[stream].extend(
                                created_records)
                            self.assertEqual(num_records, len(created_records))
                        else:
                            raise NotImplementedError(
                                "created_records unknown type: {}".format(
                                    created_records))

                print("Adjust expectations for stream: {}".format(stream))
                self.modify_expected_records(
                    stream_to_expected_records[stream])

            return stream_to_expected_records

        def get_start_date_key(self, stream):
            """
            Return the record field used to judge how recent a ``stream`` record is.

            The choice depends on the stream's replication method:
              * INCREMENTAL streams with replication keys use one of those keys
                (the first yielded when iterating them);
              * FULL streams listed in ``expected_stream_to_start_date_key()``
                use the mapped key;
              * every other stream falls back to ``'created_at'``.
            """
            method = self.expected_replication_method().get(stream)
            if method == self.INCREMENTAL:
                replication_keys = self.expected_replication_keys().get(stream)
                if replication_keys:
                    return next(iter(replication_keys))
            elif method == self.FULL:
                mapped_key = self.expected_stream_to_start_date_key().get(stream)
                if mapped_key:
                    return mapped_key
            return 'created_at'

        @staticmethod
        def _shift_to_start_of_list(key, values):
            """Return a copy of ``values`` with the first ``key`` moved to the front.

            ``values`` itself is never modified. If ``key`` is absent, the copy is
            returned in its original order.
            """
            reordered = values.copy()
            if key not in reordered:
                return reordered
            reordered.remove(key)
            return [key] + reordered

        @staticmethod
        def _shift_to_end_of_list(key, values):
            """Return a copy of ``values`` with the first ``key`` moved to the end.

            ``values`` itself is never modified. If ``key`` is absent, the copy is
            returned in its original order.
            """
            reordered = values.copy()
            if key not in reordered:
                return reordered
            reordered.remove(key)
            return reordered + [key]

        def run_and_verify_check_mode(self, conn_id):
            """
            Run the tap in check mode and verify it succeeds.
            This should be run prior to field selection and initial sync.

            Verifies the check job's exit status, that at least one catalog was
            discovered, and that the discovered ``tap_stream_id`` values exactly
            match ``self.expected_check_streams()``.

            Returns:
                The catalogs returned by ``menagerie.get_catalogs(conn_id)``.
            """
            # run in check mode
            check_job_name = runner.run_check_mode(self, conn_id)

            # verify check exit codes
            exit_status = menagerie.get_exit_status(conn_id, check_job_name)
            menagerie.verify_check_exit_status(self, exit_status,
                                               check_job_name)

            found_catalogs = menagerie.get_catalogs(conn_id)
            self.assertGreater(
                len(found_catalogs),
                0,
                msg="unable to locate schemas for connection {}".format(
                    conn_id))

            found_catalog_names = {
                catalog['tap_stream_id']
                for catalog in found_catalogs
            }
            diff = self.expected_check_streams().symmetric_difference(
                found_catalog_names)
            self.assertEqual(
                len(diff),
                0,
                msg="discovered schemas do not match: {}".format(diff))
            print("discovered schemas are OK")

            return found_catalogs

        def perform_and_verify_table_and_field_selection(
                self,
                conn_id,
                found_catalogs,
                streams_to_select,
                select_all_fields=True):
            """
            Perform table and field selection based on the set of streams to
            select and the field selection parameter, then verify the result.

            Streams in ``self.expected_streams()`` that are not in
            ``streams_to_select`` are excluded and must come back unselected.
            Each selected stream must have either every field selected
            (``select_all_fields=True``) or exactly its automatic fields
            selected (``select_all_fields=False``).
            """
            # Select all available fields or only automatic fields from all testable streams
            exclude_streams = self.expected_streams().difference(
                streams_to_select)
            self.select_all_streams_and_fields(
                conn_id=conn_id,
                catalogs=found_catalogs,
                select_all_fields=select_all_fields,
                exclude_streams=exclude_streams)

            catalogs = menagerie.get_catalogs(conn_id)

            # Ensure our selection worked
            for cat in catalogs:
                catalog_entry = menagerie.get_annotated_schema(
                    conn_id, cat['stream_id'])
                # Verify all testable streams are selected
                selected = catalog_entry.get('annotated-schema').get(
                    'selected')
                print("Validating selection on {}: {}".format(
                    cat['stream_name'], selected))
                if cat['stream_name'] not in streams_to_select:
                    self.assertFalse(selected,
                                     msg="Stream selected, but not testable.")
                    continue  # Skip remaining assertions if we aren't selecting this stream
                self.assertTrue(selected, msg="Stream not selected.")

                if select_all_fields:
                    # Verify all fields within each selected stream are selected
                    for field, field_props in catalog_entry.get(
                            'annotated-schema').get('properties').items():
                        field_selected = field_props.get('selected')
                        print("\tValidating selection on {}.{}: {}".format(
                            cat['stream_name'], field, field_selected))
                        self.assertTrue(field_selected,
                                        msg="Field not selected.")
                else:
                    # Verify only automatic fields are selected
                    expected_automatic_fields = self.expected_automatic_fields(
                    ).get(cat['tap_stream_id'])
                    selected_fields = self.get_selected_fields_from_metadata(
                        catalog_entry['metadata'])
                    self.assertEqual(expected_automatic_fields,
                                     selected_fields)

        def run_and_verify_sync(self, conn_id, clear_state=True):
            """
            Optionally clear the connection's state in menagerie, then run a sync.
            Verify the exit code following the sync.

            Args:
                conn_id: connection to sync.
                clear_state: when True (default), reset the connection's state
                    to ``{}`` before syncing.

            Returns:
                The record count by stream, as returned by
                ``runner.examine_target_output_file`` for
                ``self.expected_streams()``.
            """
            if clear_state:
                menagerie.set_state(conn_id, {})

            # run sync
            sync_job_name = runner.run_sync_mode(self, conn_id)

            # Verify tap exit codes
            exit_status = menagerie.get_exit_status(conn_id, sync_job_name)
            menagerie.verify_sync_exit_status(self, exit_status, sync_job_name)

            # read target output
            first_record_count_by_stream = runner.examine_target_output_file(
                self, conn_id, self.expected_streams(),
                self.expected_primary_keys())

            return first_record_count_by_stream

        def getPKsToRecordsDict(self, stream, records):
            """
            Index ``records`` by the tuple of their primary-key values.

            Declared primary keys (``expected_primary_keys``) are used when the
            stream has any; otherwise ``makeshift_primary_keys`` supplies them.
            A record missing a key contributes ``None`` at that position.
            """
            declared_keys = self.expected_primary_keys().get(stream)
            if declared_keys:
                key_fields = list(declared_keys)
            else:
                key_fields = self.makeshift_primary_keys().get(stream)
            return {
                tuple(map(record.get, key_fields)): record
                for record in records
            }

        ##########################################################################
        ### Standard Assertion Patterns
        ##########################################################################

        def assertPKsEqual(self,
                           stream,
                           expected_records,
                           sync_records,
                           assert_pk_count_same=False):
            """
            Compare the primary-key values of expected and synced records.

            Primary keys come from ``expected_primary_keys()`` for ``stream``,
            falling back to ``makeshift_primary_keys()`` when none are declared.
            For the comparison to be valid, both record lists are first checked
            for duplicate primary keys.

            Args:
                stream: name of the stream whose primary keys are used.
                expected_records: records the test expects to be replicated.
                sync_records: records actually emitted by the tap.
                assert_pk_count_same: when True, also require the synced pk set
                    to equal the expected pk set (no extra synced records).

            Raises:
                AssertionError: on duplicate pks in either list, when an expected
                    pk is missing from the sync, or (if requested) when the pk
                    sets differ.
            """
            declared_pks = self.expected_primary_keys().get(stream)
            primary_keys = list(declared_pks) if declared_pks \
                else self.makeshift_primary_keys().get(stream)

            # Verify there are no duplicate pks in the target
            sync_pks = [
                tuple(sync_record.get(pk) for pk in primary_keys)
                for sync_record in sync_records
            ]
            sync_pks_set = set(sync_pks)
            self.assertEqual(
                len(sync_pks),
                len(sync_pks_set),
                msg="A duplicate record may have been replicated.")

            # Verify there are no duplicate pks in our expectations
            expected_pks = [
                tuple(expected_record.get(pk) for pk in primary_keys)
                for expected_record in expected_records
            ]
            expected_pks_set = set(expected_pks)
            self.assertEqual(
                len(expected_pks),
                len(expected_pks_set),
                msg="Our expectations contain a duplicate record.")

            # Verify sync pks contain every expected record's pk; report which
            # ones are missing so a failure is diagnosable.
            missing_pks = expected_pks_set - sync_pks_set
            self.assertTrue(
                sync_pks_set.issuperset(expected_pks_set),
                msg="Expected primary keys missing from sync: {}".format(
                    missing_pks))

            if assert_pk_count_same:
                self.assertEqual(expected_pks_set, sync_pks_set)

        def assertParentKeysEqual(self, expected_record, sync_record):
            """Compare the top level keys of an expected record and a sync record."""
            self.assertEqual(frozenset(expected_record.keys()), frozenset(sync_record.keys()),
                             msg="Expected keys in expected_record to equal keys in sync_record. " +\
                             "[expected_record={}][sync_record={}]".format(expected_record, sync_record))

        ##########################################################################
        ### Tap Specific Assertions
        ##########################################################################

        def assertRecordsEqual(self, stream, expected_record, sync_record):
            """
            Compare an expected record with a synced record for ``stream``.

            Certain Square streams cannot be compared directly with
            assertDictEqual(). Some timestamp fields may differ between creation
            and sync, and 'items' may carry ``legacy_tax_ids`` only in the synced
            record. Those cases are handled here; every other stream is compared
            with assertDictEqual().
            """
            drifting_keys = {
                'payments': {'updated_at'},
                'employees': {'created_at', 'updated_at'},
                'roles': {'created_at', 'updated_at'},
                'inventories': {'calculated_at'},
            }.get(stream)
            if drifting_keys is not None:
                self.assertDictEqualWithOffKeys(expected_record, sync_record,
                                                drifting_keys)
                return

            if stream != 'items':
                self.assertDictEqual(expected_record, sync_record)
                return

            self.assertParentKeysEqual(expected_record, sync_record)
            expected_copy = deepcopy(expected_record)
            synced_copy = deepcopy(sync_record)

            # Square adds legacy_tax_ids to item_data, but not when the item is
            # created. When it appears only in the synced record it must equal
            # tax_ids, and it is then left out of the comparison.
            legacy_only_in_sync = (
                'item_data' in expected_record
                and 'item_data' in sync_record
                and 'legacy_tax_ids' in sync_record['item_data']
                and 'legacy_tax_ids' not in expected_record['item_data'])
            if legacy_only_in_sync:
                self.assertIn('tax_ids', sync_record['item_data'])
                legacy_tax_ids = synced_copy['item_data'].pop('legacy_tax_ids')
                self.assertEqual(legacy_tax_ids,
                                 sync_record['item_data']['tax_ids'])

            self.assertDictEqual(expected_copy, synced_copy)

        def assertDictEqualWithOffKeys(self,
                                       expected_record,
                                       sync_record,
                                       off_keys=frozenset()):
            """
            Compare two records, letting the ``off_keys`` values move forward.

            Both records must share the same top-level keys. For each key in
            ``off_keys``, the synced value must be >= the expected value. All
            remaining keys must match exactly. The inputs are not modified;
            work happens on deep copies.
            """
            self.assertParentKeysEqual(expected_record, sync_record)
            remaining_expected = deepcopy(expected_record)
            remaining_synced = deepcopy(sync_record)

            # Square api workflow updates these values, so they're a few seconds
            # different between the time the record is created and the tap syncs,
            # but other fields are the same.
            for drifting_key in off_keys:
                synced_value = remaining_synced.pop(drifting_key)
                expected_value = remaining_expected.pop(drifting_key)
                self.assertGreaterEqual(synced_value, expected_value)
            self.assertDictEqual(remaining_expected, remaining_synced)
Ejemplo n.º 20
0
"""
This is used for testing basic functionality of the test client.
To run change the desired flags below and use the following command from the tap-tester repo:
  'python ../tap-adroll/tests/client_tester.py'
"""
from test_client import TestClient

##########################################################################
# Testing the TestCLient
##########################################################################
if __name__ == "__main__":
    client = TestClient()

    # CHANGE FLAGS HERE TO TEST SPECIFIC FUNCTION TYPES
    test_creates = False
    test_updates = True
    test_gets = False
    test_deletes = False

    # CHANGE FLAG TO PRINT ALL OBJECTS THAT FUNCTIONS INTERACT WITH
    print_objects = True

    objects_to_test = [  # CHANGE TO TEST DESIRED STREAMS 
    ]
    # 'ads', # GET - DONE | CREATE - DONE | UPDATE - DONE
    # 'ad_groups', # GET - DONE | CREATE - DONE | UPDATE - DONE
    # 'segments', # GET - DONE | CREATE - DONE | UPDATE - DONE
    # 'campaigns', # GET - DONE | CREATE - DONE | UPDATES - DONE
    # 'advertisables', # GET - DONE | CREATE NA (DONT DO THIS ONE)
    # 'ad_reports', GET - DONE | CREATE - N/A
Ejemplo n.º 21
0
 def setUpClass(cls):
     """Create one TestClient shared by the whole test class as ``cls.tc``."""
     # NOTE(review): no @classmethod decorator is visible in this fragment;
     # unittest calls setUpClass on the class, so it needs one -- confirm.
     cls.tc = TestClient()
Ejemplo n.º 22
0
class TestProposalsAPI(unittest.TestCase):
    """End-to-end tests for the proposals API.

    setUp creates two users. The first posts a meal request; the client is
    then switched to the second user, who makes proposals for that request.
    """
    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Reset the database and create both users plus one request."""
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()

        # Creating the 1st user. This user will issue a request
        user = User(username=self.default_username)
        user.set_password_hash(self.default_password)
        db.session.add(user)
        db.session.commit()
        self.client = TestClient(self.app, user.generate_auth_token(), '')

        # Create a request for 1st user
        request_data = {
            'meal_type': 'Vietnamese',
            'meal_time': 'Dinner',
            'location_string': 'San Francisco'
        }
        rv, json = self.client.post(API_VERSION + '/requests/',
                                    data=request_data)
        self.request_location = rv.headers['Location']

        # Create the 2nd user. This user will make proposal for the request by
        # 1st user
        user_2 = User(username='******')
        user_2.set_password_hash('123456')
        db.session.add(user_2)
        db.session.commit()
        self.client = TestClient(self.app, user_2.generate_auth_token(), '')

    def tearDown(self):
        """Drop the session and tables, then pop the app context."""
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_misc(self):
        root_endpoint = API_VERSION + '/'
        rv, json = self.client.get(root_endpoint)
        self.assertEqual(rv.status_code, 200)

    def test_proposals(self):
        """Create, list, update and delete a proposal for the request."""
        proposals_endpoint = API_VERSION + '/proposals/'
        rv, json = self.client.get(proposals_endpoint)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(len(json['proposals']), 0)

        # Add new proposal
        rv, json = self.client.get(self.request_location)
        self.assertEqual(rv.status_code, 200)

        proposals_url = json['proposals_url']
        rv, json = self.client.post(proposals_url, data={'a': 'b'})
        self.assertEqual(rv.status_code, 201)

        proposal_location = rv.headers['Location']
        rv, json = self.client.get(proposal_location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['request_url'], self.request_location)
        self.assertEqual(json['self_url'], proposal_location)
        self.assertEqual(json['user_url'].split('/')[-1], str(g.user.id))

        rv, json = self.client.get(proposals_url)
        self.assertEqual(rv.status_code, 200)
        self.assertIn(proposal_location, json['proposals'])

        # Update proposal
        rv, json = self.client.put(proposal_location, data={'accepted': True})
        self.assertEqual(rv.status_code, 200)

        rv, json = self.client.get(proposal_location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['accepted'], True)

        # Delete proposal
        rv, json = self.client.delete(proposal_location)
        self.assertEqual(rv.status_code, 200)

        with self.assertRaises(NotFound):
            self.client.get(proposal_location)
Ejemplo n.º 23
0
class TestUsersAPI(unittest.TestCase):
    """CRUD tests for the users API."""
    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Reset the database and create the default user and its client."""
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()
        user = User(username=self.default_username)
        user.set_password_hash(self.default_password)
        db.session.add(user)
        db.session.commit()
        self.client = TestClient(self.app, user.generate_auth_token(), '')

    def tearDown(self):
        """Drop the session and tables, then pop the app context."""
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_misc(self):
        root_endpoint = API_VERSION + '/'
        rv, json = self.client.get(root_endpoint)
        self.assertEqual(rv.status_code, 200)

    def test_users(self):
        """List, create, reject a duplicate, edit and delete users."""
        # Get list of all users
        users_endpoint = API_VERSION + '/users/'
        rv, json = self.client.get(users_endpoint)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(len(json['users']), 1)

        # Add a new user
        user_data = {
            'username': '******',
            'password': '******',
            'email': '*****@*****.**',
            'photo_url': 'http://www.gotitapp.co/user.jpg'
        }

        rv, json = self.client.post(users_endpoint, data=user_data)
        self.assertEqual(rv.status_code, 201)
        location = rv.headers['Location']
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['username'], user_data['username'])
        self.assertEqual(json['email'], user_data['email'])
        self.assertEqual(json['photo_url'], user_data['photo_url'])

        # Add a duplicated user
        duplicated_user_data = {
            'username': '******',
            'password': '******',
            'email': '*****@*****.**',
            'photo_url': 'http://www.gotitapp.co/user_2.jpg'
        }
        with self.assertRaises(ValidationError) as context:
            self.client.post(users_endpoint, data=duplicated_user_data)
        # Checked after the with-block: statements after the raising call
        # inside it never run. NOTE(review): assumes str() of the exception
        # carries this message -- confirm against ValidationError.
        self.assertIn('Username already existed!', str(context.exception))

        # Edit a user
        updated_user_data = {
            'username': '******',
            'email': '*****@*****.**',
            'photo_url': 'http://www.gotitapp.co/user2.jpg'
        }

        rv, json = self.client.put(location, data=updated_user_data)
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['username'], updated_user_data['username'])
        self.assertEqual(json['email'], updated_user_data['email'])
        self.assertEqual(json['photo_url'], updated_user_data['photo_url'])

        # Delete a user
        rv, json = self.client.delete(location)
        self.assertEqual(rv.status_code, 200)

        with self.assertRaises(NotFound) as context:
            self.client.get(location)
        self.assertIn('Not Found', str(context.exception))
Ejemplo n.º 24
0
 def __init__(self, num_length):
     """Create the TestClient and a buffer of ``num_length`` None slots."""
     self.test_client = TestClient()
     self.data_array = [None for _ in range(num_length)]
     self.num_length = num_length
 def setUpClass(cls):
     """Print a setup banner and create one TestClient shared as ``cls.client``."""
     print("\n\nTEST SETUP\n")
     cls.client = TestClient()
Ejemplo n.º 26
0
 def setUp(self):
     """Create a fresh TestClient and keep a handle for counting Subscriptions."""
     self.c = TestClient()
     # NOTE(review): this stores the bound ``count`` callable itself (no call
     # parentheses), not a number; tests presumably invoke ``self.count()``.
     self.count = Subscription.objects.count
Ejemplo n.º 27
0
class TestAPI(unittest.TestCase):
    """CRUD tests for the customers, products, orders and items endpoints."""
    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Reset the database and create the default user and its client."""
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        db.drop_all()
        db.create_all()
        u = User(username=self.default_username)
        u.set_password(self.default_password)
        db.session.add(u)
        db.session.commit()
        self.client = TestClient(self.app, u.generate_auth_token(), '')

    def tearDown(self):
        """Drop the session and tables, then pop the app context."""
        db.session.remove()
        db.drop_all()
        self.ctx.pop()

    def test_customers(self):
        # get list of customers
        rv, json = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['customers'], [])

        # add a customer
        rv, json = self.client.post('/customers/', data={'name': 'john'})
        self.assertEqual(rv.status_code, 201)
        location = rv.headers['Location']
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['name'], 'john')
        rv, json = self.client.get('/customers/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['customers'], [location])

        # edit the customer
        rv, json = self.client.put(location, data={'name': 'John Smith'})
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['name'], 'John Smith')

    def test_products(self):
        # get list of products
        rv, json = self.client.get('/products/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['products'], [])

        # add a product
        rv, json = self.client.post('/products/', data={'name': 'prod1'})
        self.assertEqual(rv.status_code, 201)
        location = rv.headers['Location']
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['name'], 'prod1')
        rv, json = self.client.get('/products/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['products'], [location])

        # edit the product
        rv, json = self.client.put(location, data={'name': 'product1'})
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(location)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['name'], 'product1')

    def test_orders_and_items(self):
        # define a customer
        rv, json = self.client.post('/customers/', data={'name': 'john'})
        self.assertEqual(rv.status_code, 201)
        customer = rv.headers['Location']
        rv, json = self.client.get(customer)
        orders_url = json['orders_url']
        rv, json = self.client.get(orders_url)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['orders'], [])

        # define two products
        rv, json = self.client.post('/products/', data={'name': 'prod1'})
        self.assertEqual(rv.status_code, 201)
        prod1 = rv.headers['Location']
        rv, json = self.client.post('/products/', data={'name': 'prod2'})
        self.assertEqual(rv.status_code, 201)
        prod2 = rv.headers['Location']

        # create an order
        rv, json = self.client.post(orders_url,
                                    data={'date': '2014-01-01T00:00:00Z'})
        self.assertEqual(rv.status_code, 201)
        order = rv.headers['Location']
        rv, json = self.client.get(order)
        items_url = json['items_url']
        rv, json = self.client.get(items_url)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['items'], [])
        rv, json = self.client.get('/orders/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(len(json['orders']), 1)
        self.assertIn(order, json['orders'])

        # edit the order
        rv, json = self.client.put(order,
                                   data={'date': '2014-02-02T00:00:00Z'})
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(order)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['date'], '2014-02-02T00:00:00Z')

        # add two items to order
        rv, json = self.client.post(items_url,
                                    data={
                                        'product_url': prod1,
                                        'quantity': 2
                                    })
        self.assertEqual(rv.status_code, 201)
        item1 = rv.headers['Location']
        rv, json = self.client.post(items_url,
                                    data={
                                        'product_url': prod2,
                                        'quantity': 1
                                    })
        self.assertEqual(rv.status_code, 201)
        item2 = rv.headers['Location']
        rv, json = self.client.get(items_url)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(len(json['items']), 2)
        self.assertIn(item1, json['items'])
        self.assertIn(item2, json['items'])
        rv, json = self.client.get(item1)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['product_url'], prod1)
        self.assertEqual(json['quantity'], 2)
        self.assertEqual(json['order_url'], order)
        rv, json = self.client.get(item2)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['product_url'], prod2)
        self.assertEqual(json['quantity'], 1)
        self.assertEqual(json['order_url'], order)

        # edit the second item
        rv, json = self.client.put(item2,
                                   data={
                                       'product_url': prod2,
                                       'quantity': 3
                                   })
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(item2)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(json['product_url'], prod2)
        self.assertEqual(json['quantity'], 3)
        self.assertEqual(json['order_url'], order)

        # delete first item
        rv, json = self.client.delete(item1)
        self.assertEqual(rv.status_code, 200)
        rv, json = self.client.get(items_url)
        self.assertNotIn(item1, json['items'])
        self.assertIn(item2, json['items'])

        # delete order (its items must be gone too)
        rv, json = self.client.delete(order)
        self.assertEqual(rv.status_code, 200)
        with self.assertRaises(NotFound):
            self.client.get(item2)
        rv, json = self.client.get('/orders/')
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(len(json['orders']), 0)
Ejemplo n.º 28
0
class APITestCase(unittest.TestCase):
    """End-to-end tests for the task/user JSON API.

    Each test runs inside a pushed ``testing`` app context against a freshly
    created schema, with ``self.client`` authenticated as the default user
    through an auth token.
    """

    default_username = '******'
    default_password = '******'

    def setUp(self):
        """Create the app, the schema and an authenticated client."""
        self.app = create_app('testing')
        self.app_context = self.app.app_context()
        self.app_context.push()
        db.create_all()
        u = User(username=self.default_username,
                 password=self.default_password)
        db.session.add(u)
        db.session.commit()

        self.client = TestClient(self.app, u.generate_auth_token(), '')

    def tearDown(self):
        """Discard the session, drop all tables and pop the app context."""
        db.session.remove()
        db.drop_all()
        self.app_context.pop()

    def test_404(self):
        """An unknown URL yields a 404 whose JSON body reports 'not found'."""
        response, json_response = self.client.get('/wrong/url')
        self.assertEqual(response.status_code, 404)
        self.assertEqual(json_response['error'], 'not found')

    def test_task(self):
        """Create, read, update and delete a single task."""
        test_task = {'title': 'Smash Patriarchy'}

        # create task
        response, json_response = self.client.post(url_for('api.create_task'),
                                                   data=test_task)
        self.assertEqual(response.status_code, 201)
        # assertTrue(a, b) treats b as the failure message and only checks
        # that a is truthy, so compare the values explicitly.
        self.assertEqual(json_response['title'], test_task['title'])
        inputted_id = json_response['id']

        # get task
        response, json_response = self.client.get(
            url_for('api.get_task', task_id=inputted_id, _external=True))
        self.assertEqual(json_response['title'], test_task['title'])

        # update task
        response, json_response = self.client.put(url_for('api.update_task',
                                                          task_id=inputted_id),
                                                  data={'done': True})
        self.assertEqual(response.status_code, 201)
        self.assertTrue(json_response['done'])

        # delete task
        response, json_response = self.client.delete(
            url_for('api.delete_task', task_id=inputted_id))
        self.assertEqual(response.status_code, 201)

    def test_user(self):
        """Create, read and delete a user through the API."""
        test_user = {'username': '******'}

        # create user
        response, json_response = self.client.post(url_for('api.create_user'),
                                                   data=test_user)
        self.assertEqual(response.status_code, 201)
        self.assertEqual(json_response['username'], test_user['username'])
        inputted_id = json_response['id']

        # get user
        response, json_response = self.client.get(
            url_for('api.get_user', user_id=inputted_id, _external=True))
        self.assertEqual(json_response['username'], test_user['username'])

        # delete user
        response, json_response = self.client.delete(
            url_for('api.delete_user', user_id=inputted_id))
        self.assertEqual(response.status_code, 201)

    def test_relations(self):
        """A user sees only their own tasks and cannot read another user's."""
        user2_json = {'username': '******', 'password': '******'}

        user2 = User.from_json(user2_json)
        db.session.add(user2)
        db.session.commit()

        user2_tasks = [{
            'title': 'destroy the state'
        }, {
            'title': 'form cooperatives'
        }]
        user1_tasks = [{
            'title': 'form a proletarian dictatorship'
        }, {
            'title': 'implement the 5 year plan'
        }, {
            'title': 'decide on a successor'
        }]

        # we put the tasks 'manually' in order to not login in and out of user1
        for task_json in user2_tasks:
            task = Task.from_json(task_json)
            task.user_id = user2.id
            db.session.add(task)
        db.session.commit()

        # create task of user1; fail here (not in a later count check) if any
        # creation is rejected
        for task in user1_tasks:
            response, json_response = self.client.post(
                url_for('api.create_task'), data=task)
            self.assertEqual(response.status_code, 201)

        # get tasks of user1
        response, json_response = self.client.get(url_for('api.get_tasks'))
        self.assertEqual(response.status_code, 200)

        # check if only the task of the user is given
        self.assertEqual(len(json_response['tasks']), 3)

        response_titles = [task['title'] for task in json_response['tasks']]
        input_titles = [task['title'] for task in user1_tasks]

        self.assertEqual(set(input_titles), set(response_titles))

        # check that access to a task of user2 is forbidden
        user2_task = user2.tasks.first()
        self.assertIsNotNone(user2_task)

        response, json_response = self.client.get(
            url_for('api.get_task', task_id=user2_task.id))
        self.assertEqual(response.status_code, 403)
Ejemplo n.º 29
0
class TestDatesAPI(unittest.TestCase):
    def setUp(self):
        """Prepare a clean database holding two users and one meal request.

        Afterwards ``self.client_1`` is authenticated as ``self.requester``,
        who owns the request at ``self.request_location``.
        ``self.client_2`` is authenticated as a second user who will make
        proposals for that request.
        """
        self.app = app
        self.ctx = self.app.app_context()
        self.ctx.push()
        # Rebuild the schema from scratch so no rows survive a prior test.
        db.drop_all()
        db.create_all()

        def register(name):
            # Persist a user whose password is set via set_password_hash.
            member = User(username=name)
            member.set_password_hash('password_1')
            db.session.add(member)
            db.session.commit()
            return member

        def client_for(member):
            # API client authenticated with the member's auth token.
            return TestClient(self.app, member.generate_auth_token(), '')

        # 1st user: issues the meal request.
        self.requester = register('******')
        self.client_1 = client_for(self.requester)

        meal = dict(meal_type='Vietnamese',
                    meal_time='Dinner',
                    location_string='San Francisco')
        created, _ = self.client_1.post(API_VERSION + '/requests/', data=meal)
        self.request_location = created.headers['Location']

        # 2nd user: will make a proposal for the 1st user's request.
        self.client_2 = client_for(register('******'))

    def tearDown(self):
        """Release the DB session, drop every table, then leave the app context."""
        session = db.session
        session.remove()
        db.drop_all()
        # The context pushed in setUp must be popped last.
        context = self.ctx
        context.pop()

    def test_misc(self):
        """The API root endpoint answers 200 for the authenticated requester."""
        root_endpoint = API_VERSION + '/'
        rv, _ = self.client_1.get(root_endpoint)
        # assertEqual reports the actual status code when the check fails,
        # unlike assertTrue(rv.status_code == 200).
        self.assertEqual(rv.status_code, 200)

    def test_dates(self):
        proposals_endpoint = API_VERSION + '/dates/'
        rv, json = self.client_1.get(proposals_endpoint)
        self.assertTrue(rv.status_code == 200)
        self.assertTrue(len(json['dates']) == 0)

        # Proposer add a proposal for the request by Requester
        rv, json = self.client_2.get(self.request_location)
        self.assertTrue(rv.status_code == 200)

        proposals_url = json['proposals_url']
        rv, json = self.client_2.post(proposals_url, data={'foo': 'bar'})
        self.assertTrue(rv.status_code == 201)
        proposal_location = rv.headers['Location']
        rv, json = self.client_2.get(proposal_location)
        date_url = json['date_url']
        self.assertTrue(rv.status_code == 200)
        self.assertTrue(json['request_url'] == self.request_location)
        self.assertTrue(json['self_url'] == proposal_location)
        self.assertTrue(json['user_url'].split('/')[-1] == str(g.user.id))

        # Requester accept a proposal
        rv, json = self.client_1.post(date_url, data={'foo': 'bar'})
        self.assertTrue(rv.status_code == 201)
        date_location = rv.headers['Location']
        rv, json = self.client_1.get(date_location)
        self.assertTrue(rv.status_code == 200)
        rv, json = self.client_2.get(date_location)
        self.assertTrue(rv.status_code == 200)

        # Requester update a date
        rv, json = self.client_1.put(date_location,
                                     data={'restaurant_name': 'Japanese'})
        self.assertTrue(rv.status_code == 200)
        rv, json = self.client_1.get(date_location)
        self.assertTrue(rv.status_code == 200)
        self.assertTrue(json['restaurant_name'] == 'Japanese')

        # Requester delete a date
        rv, json = self.client_1.delete(date_location)
        self.assertTrue(rv.status_code == 200)
        with self.assertRaises(NotFound):
            rv, json = self.client_1.get(date_location)