Tuesday, October 03, 2017

mocking AWS Lambdas

When dealing with AWS services, it can be tedious to mock the endpoints. Luckily, there's a wonderful module called moto which takes care of this for you. moto supports the majority of the AWS backends, with varying degrees of completeness. Recently I overhauled its lambda backend and added support for running lambdas in the environment they're specified to run in (JS, Python, etc.), as well as for linking them to SNS events and CloudWatch logging.

In my use-case I had:
  1. My test cases, running on bare metal, which included:
    1. Mocking the sns, lambda, s3, kms, logs, and cloudwatch endpoints via moto.
    2. Registering an AWS lambda that connects to a mocked SNS endpoint.
    3. Mocking Google endpoints via a custom aiohttp server.
  2. A Docker container that forwarded messages from a Google PubSub endpoint (mocked via the PubSub emulator) to the moto-mocked SNS endpoint, which triggered the mocked lambda.
  3. Another container that registered subscriptions from the mocked Google services to the PubSub endpoint, and occasionally triggered the lambda via the SNS endpoint.

Each moto mock endpoint was created via my helper class:

class MotoService:
    """Run a moto mock server for one AWS service in a background thread.

    Service is ref-counted so there is only one per process and service name.
    The first ``__aenter__`` for a name starts the server and registers that
    instance. Later ``__aenter__`` calls for the same name return the
    registered instance and bump its reference count. The server is stopped
    once every matching ``__aexit__`` has released it. The real service is
    returned by ``__aenter__`` and may be a different object from the one it
    was called on.

    An instance can also decorate a coroutine function. The server is then
    started before and stopped after each call, outside the ref-counting.
    """

    _services: Dict[str, Any] = dict()  # {name: registered instance}

    def __init__(self, service_name: str, port: int = None):
        """
        :param service_name: moto service name, e.g. ``'s3'`` or ``'sns'``.
        :param port: port to serve on. When falsy, a port and a socket
            (presumably holding that port) come from ``get_free_tcp_port()``;
            the socket is closed right before the server starts.
        """
        self._service_name = service_name

        if port:
            self._socket = None
            self._port = port
        else:
            self._socket, self._port = get_free_tcp_port()

        self._thread = None
        self._logger = logging.getLogger('MotoService')
        self._refcount = None
        self._ip_address = get_ip_address()

    @property
    def endpoint_url(self):
        """Base URL of the mock server: ``http://<ip>:<port>``."""
        return 'http://{}:{}'.format(self._ip_address, self._port)

    def __call__(self, func):
        """Decorate coroutine function *func* to run the server around each call."""
        @functools.wraps(func)  # also sets wrapper.__wrapped__
        async def wrapper(*args, **kwargs):
            await self._start()
            try:
                return await func(*args, **kwargs)
            finally:
                await self._stop()

        return wrapper

    async def __aenter__(self):
        svc = self._services.get(self._service_name)
        if svc is not None:
            # Reuse the running service. Our own reserved port will never be
            # served, so release it now rather than holding it until __aexit__.
            self._release_socket()
            svc._refcount += 1
            return svc

        self._services[self._service_name] = self
        self._refcount = 1
        try:
            await self._start()
        except BaseException:
            # Don't leave a dead service registered for later callers.
            del self._services[self._service_name]
            self._refcount = None
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # `async with MotoService(name)` calls __aexit__ on the instance that
        # was constructed, which is not necessarily the shared instance that
        # __aenter__ returned, so release the reference on the registered one.
        svc = self._services.get(self._service_name, self)
        self._release_socket()
        svc._refcount -= 1

        if svc._refcount == 0:
            if self._services.get(self._service_name) is svc:
                del self._services[self._service_name]
            await svc._stop()

    def _release_socket(self):
        """Close the socket held for our reserved port, if any."""
        if self._socket:
            self._socket.close()
            self._socket = None

    @staticmethod
    def _shutdown():
        """Flask view for ``/shutdown``: stop the werkzeug development server.

        NOTE: relies on the ``werkzeug.server.shutdown`` environ hook, which
        Werkzeug deprecated in 2.0 and removed in 2.1.
        """
        req = flask.request
        shutdown = req.environ['werkzeug.server.shutdown']
        shutdown()
        return flask.make_response('done', 200)

    def _create_backend_app(self, *args, **kwargs):
        """Build moto's backend app and add the ``/shutdown`` route to it."""
        backend_app = moto.server.create_backend_app(*args, **kwargs)
        backend_app.add_url_rule('/shutdown', 'shutdown', self._shutdown)
        return backend_app

    def _server_entry(self):
        """Thread target: build the moto WSGI app and serve it until shut down."""
        self._main_app = moto.server.DomainDispatcherApplication(self._create_backend_app, service=self._service_name)
        self._main_app.debug = True

        self._release_socket()  # release the reserved port right before we use it

        moto.server.run_simple(self._ip_address, self._port, self._main_app, threaded=True)

    async def _start(self):
        """Start the server thread and wait until it answers HTTP requests.

        Polls up to 10 times, sleeping 0.5s after each failed attempt.

        :raises Exception: if the server thread exits or never responds.
        """
        self._thread = threading.Thread(target=self._server_entry, daemon=True)
        self._thread.start()

        async with aiohttp.ClientSession() as session:
            for _ in range(10):
                if not self._thread.is_alive():
                    break  # server thread died; reported below

                try:
                    # we need to bypass the proxies due to monkeypatches.
                    # Any HTTP response (even a 404) means the server is up.
                    async with session.get(self.endpoint_url + '/static/', timeout=0.5):
                        pass
                    return
                except (asyncio.TimeoutError, aiohttp.ClientConnectionError):
                    await asyncio.sleep(0.5)

        if self._thread.is_alive():
            # pytest.fail doesn't call stop_process, so try to clean up here.
            try:
                await self._stop()
            except Exception:
                pass  # already logged by _stop; report the startup failure instead
        raise Exception("Can not start service: {}".format(self._service_name))

    async def _stop(self):
        """Ask the server to shut down and wait for its thread to exit.

        Does nothing if the server was never started.

        :raises: whatever the shutdown request raised, after logging it.
        """
        if self._thread is None:
            return

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.endpoint_url + '/shutdown', timeout=5):
                    pass
        except Exception:
            self._logger.exception("Error stopping moto service")
            # The server never acknowledged the shutdown, so an unbounded join
            # could hang forever; the thread is a daemon anyway.
            self._thread.join(timeout=5)
            raise

        self._thread.join()
        self._thread = None

My setUpClass looked something like the following:


@classmethod
    def setUpClass(cls):
        cls._pubsub_port = get_free_tcp_port(True)
        cls._gcloud_enumlator = subprocess.Popen(["gcloud", "beta", "emulators", "pubsub", "start", "--host-port={}:{}".format(IP_ADDRESS, cls._pubsub_port)], preexec_fn=os.setsid)

        boto_service_names = {'sns', 'lambda', 's3', 'kms', 'logs', 'cloudwatch'}
        cls._boto_svcs = {}

        async def start_svc(svc_name):
            cls._boto_svcs[svc_name] = await MotoService(svc_name).__aenter__()

        try:
            loop = asyncio.get_event_loop()
            loop.run_until_complete(asyncio.gather(*[start_svc(svc_name) for svc_name in boto_service_names]))

            cls._mock_env_vars = {'{}_mock_endpoint_url'.format(name): svc.endpoint_url + '/' for name, svc in cls._boto_svcs.items()}
            cls._mock_env_vars['PUBSUB_EMULATOR_HOST'] = '{}:{}'.format(IP_ADDRESS, cls._pubsub_port)
            cls._mock_env_vars['AWS_DEFAULT_REGION'] = AWS_DEFAULT_REGION

            for name, value in cls._mock_env_vars.items():
                os.environ[name] = value

            session = botocore.session.get_session()
            cls._boto_clients = {svc_name: session.create_client(svc_name) for svc_name in boto_service_names}
        except:
            cls.tearDownClass()
            raise

After that, things like S3 and KMS were set up. One of the more interesting pieces was the lambda function connected to SNS, which looked like this:


with open(os.path.join(CURRENT_DIR, '..', 'lambda_image', 'lambda_labeler_image.zip'), 'rb') as zip_file:
            # Register the packaged lambda (read from the zip on disk) with the
            # mocked lambda service.
            lambda_response = self._boto_clients['lambda'].create_function(
                FunctionName=LAMBDA_FUNCTION_NAME, Runtime='python3.6',
                Role='test-iam-role', Handler='lambda_function.lambda_handler',
                Timeout=15, MemorySize=128, Publish=True,
                Code={'ZipFile': zip_file.read()},
                Environment={
                    # The mock endpoint URLs are handed to the lambda as
                    # environment variables, along with debug logging and a
                    # UNITTEST flag.
                    # NOTE(review): `mock_env_vars` is not defined in this
                    # snippet; setUpClass stores these as `cls._mock_env_vars` --
                    # confirm which name is meant.
                    'Variables': {**mock_env_vars, 'LOG_LEVEL': str(logging.DEBUG), 'UNITTEST': 'true'}
                })

        # now subscribe lambda function to SNS topic, so publishing to the
        # topic on the mocked SNS endpoint invokes it
        self._boto_clients['sns'].subscribe(TopicArn=self._sns_topic_arn, Protocol='lambda', Endpoint=lambda_response['FunctionArn'])

Note that I forward the `*_mock_endpoint_url` values via environment variables.

I linked all the containers via a Docker Compose file that had something like the following:


environment:
      # ${VAR} values are interpolated from the environment docker-compose is
      # run in; setUpClass exports the AWS *_mock_endpoint_url values,
      # PUBSUB_EMULATOR_HOST and AWS_DEFAULT_REGION.
      # NOTE(review): no cloudwatch_mock_endpoint_url is forwarded even though a
      # cloudwatch mock is started -- add it if a container needs it.
      AWS_ACCESS_KEY_ID: dummy
      AWS_SECRET_ACCESS_KEY: dummy
      sns_mock_endpoint_url: ${sns_mock_endpoint_url}
      lambda_mock_endpoint_url: ${lambda_mock_endpoint_url}
      s3_mock_endpoint_url: ${s3_mock_endpoint_url}
      kms_mock_endpoint_url: ${kms_mock_endpoint_url}
      logs_mock_endpoint_url: ${logs_mock_endpoint_url}
      google_mock_endpoint_url: ${google_mock_endpoint_url}
      PUBSUB_EMULATOR_HOST: ${PUBSUB_EMULATOR_HOST}
      AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION}


Each container (including the lambda) then ran the following to forward API calls to the mocked endpoints:


def _wrapt_boto_create_client(wrapped, instance, args, kwargs):
    def unwrap_args(service_name, region_name=None, api_version=None,
                    use_ssl=True, verify=None, endpoint_url=None,
                    aws_access_key_id=None, aws_secret_access_key=None,
                    aws_session_token=None, config=None):

        if endpoint_url is None:
            endpoint_url = os.environ.get('{}_mock_endpoint_url'.format(service_name))

        return wrapped(service_name, region_name, api_version, use_ssl, verify,
                       endpoint_url, aws_access_key_id, aws_secret_access_key,
                       aws_session_token, config)

    return unwrap_args(*args, **kwargs)


def patch_boto():
    """Patch botocore so clients default to mocked endpoints.

    Wraps ``botocore.session.Session.create_client`` so that, when no
    ``endpoint_url`` is given, it is taken from the
    ``{service_name}_mock_endpoint_url`` environment variable if available.
    """
    wrapt.wrap_function_wrapper(
        'botocore.session',
        'Session.create_client',
        _wrapt_boto_create_client,
    )


# URL prefixes of the Google services whose requests are redirected to the
# mock endpoint.
_redir_prefix = {
    'https://{}/'.format(host)
    for host in ('www.googleapis.com', 'accounts.google.com', 'people.googleapis.com')
}


def _replace_url_prefix(url: str, redir_endpoint: str):
    """Point a Google API URL at the mock endpoint.

    :param url: URL to rewrite; ``None`` is returned unchanged.
    :param redir_endpoint: base URL of the mock server. When empty or
        ``None``, no redirection is done.
    :return: ``url`` with its leading Google prefix (one of ``_redir_prefix``)
        swapped for ``redir_endpoint``. It is returned unchanged if redirection
        is disabled or ``url`` already points at ``redir_endpoint``.
    :raises ValueError: if ``url`` starts with none of the known prefixes.
    """
    if not redir_endpoint or url is None:
        return url

    if url.startswith(redir_endpoint):
        return url

    for prefix in _redir_prefix:
        if url.startswith(prefix):
            # Swap only the leading prefix; str.replace would also rewrite any
            # later occurrence, e.g. inside a query parameter.
            return redir_endpoint + url[len(prefix):]

    # An explicit raise (unlike `assert`) still fires under `python -O`.
    raise ValueError('No mock redirect configured for URL: {}'.format(url))


def _wrapped_discovery_resource_init(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``googleapiclient.discovery.Resource.__init__``.

    Rewrites the ``baseUrl`` argument with ``_replace_url_prefix``, so the
    resource talks to the endpoint in the ``google_mock_endpoint_url``
    environment variable. Nothing changes when that variable is unset.

    The call is forwarded with the caller's own positional/keyword layout and
    only ``baseUrl`` replaced. Parameters that newer googleapiclient releases
    may add are therefore passed through instead of being rejected.

    :param wrapped: the bound ``Resource.__init__``.
    :param instance: the ``Resource`` being initialised (unused).
    :param args: positional arguments of the original call.
    :param kwargs: keyword arguments of the original call.
    """
    redir_endpoint = os.environ.get('google_mock_endpoint_url')

    args = list(args)
    kwargs = dict(kwargs)
    # baseUrl is the second parameter after self: (http, baseUrl, ...)
    if len(args) > 1:
        args[1] = _replace_url_prefix(args[1], redir_endpoint)
    elif 'baseUrl' in kwargs:
        kwargs['baseUrl'] = _replace_url_prefix(kwargs['baseUrl'], redir_endpoint)

    return wrapped(*args, **kwargs)


def _wrapped_oath2_credentials_init(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``oauth2client.client.OAuth2Credentials.__init__``.

    Redirects the ``token_uri`` and ``revoke_uri`` arguments to the endpoint
    in the ``google_mock_endpoint_url`` environment variable. Nothing changes
    when that variable is unset. A ``None`` URL (``revoke_uri`` defaults to
    ``None``) is left as ``None`` rather than passed to ``_replace_url_prefix``.

    The call is forwarded with the caller's own positional/keyword layout and
    only those two arguments replaced. Other parameters, including ones this
    wrapper does not know about, pass through untouched, and keyword
    arguments are not turned into positional ones.

    :param wrapped: the bound ``OAuth2Credentials.__init__``.
    :param instance: the credentials object being initialised (unused).
    :param args: positional arguments of the original call.
    :param kwargs: keyword arguments of the original call.
    """
    redir_endpoint = os.environ.get('google_mock_endpoint_url')

    def redirect(url):
        return None if url is None else _replace_url_prefix(url, redir_endpoint)

    args = list(args)
    kwargs = dict(kwargs)
    # (index among the non-self positional parameters, parameter name):
    # access_token, client_id, client_secret, refresh_token, token_expiry,
    # token_uri (5), user_agent, revoke_uri (7), ...
    for index, name in ((5, 'token_uri'), (7, 'revoke_uri')):
        if index < len(args):
            args[index] = redirect(args[index])
        elif name in kwargs:
            kwargs[name] = redirect(kwargs[name])

    return wrapped(*args, **kwargs)


def patch_google_client():
    """Patch the Google client libraries so their URLs go to the mock endpoint.

    Wraps ``googleapiclient.discovery.Resource.__init__`` with
    ``_wrapped_discovery_resource_init``. Wraps
    ``oauth2client.client.OAuth2Credentials.__init__`` with
    ``_wrapped_oath2_credentials_init``.
    """
    patches = (
        ('googleapiclient.discovery', 'Resource.__init__', _wrapped_discovery_resource_init),
        ('oauth2client.client', 'OAuth2Credentials.__init__', _wrapped_oath2_credentials_init),
    )
    for module_name, attr_path, wrapper in patches:
        wrapt.wrap_function_wrapper(module_name, attr_path, wrapper)


# Run in start-up code: install the patches only when mock endpoints are
# configured through the environment.
test_mock_endpoints = {
    key: val
    for key, val in os.environ.items()
    if key.endswith("_mock_endpoint_url")
}
test_google_endpoint_url = os.environ.get('google_mock_endpoint_url')

if test_mock_endpoints:
    patch_boto()
if test_google_endpoint_url:
    patch_google_client()

Along with several helpers, my test case looked something like this:


user_svc = self._google_svcs.dir_svc.users[ADMIN_EMAIL]

        # personal address
        headers = {
            'Date': email.utils.format_datetime(datetime.utcnow().replace(tzinfo=timezone.utc), True),
            'From': "{}".format(personal_email_addr),
            'To': '{} <{}>'.format(user_svc.user_obj['primaryEmail'], user_svc.user_obj['name']['fullName']),
            'Subject': "Dummy",
        }

        archive_msg_obj = user_svc.add_message([ARCHIVE_LABEL_NAME, 'FOLDER'], headers, 'dummy message')

        await self._wait_for_log_entry("Finished processing messages for user: {}".format(ADMIN_EMAIL))
        for msg in user_svc.messages.values():
            lbl_names = user_svc.get_lbl_names(msg['labelIds'])
            self.assertEqual(set(lbl_names), {EXPECTED_LABEL_NAME, 'FOLDER'})



This way, each container and lambda invocation forwards AWS and Google client API calls to the mocked endpoints. The end result is a chain of mocked services:

1. Hitting a mocked Google endpoint publishes a message to the mocked Google PubSub endpoint.
2. The message is forwarded to the moto-mocked SNS endpoint.
3. SNS triggers the lambda, which logs to CloudWatch.
4. My test case picks up that log entry.

If others are interested in the Google endpoint mock, I can amend this post with that information as well. All in all, I'm really happy with this workflow. It mirrors running in production almost exactly, with minimal changes to the production code, which helps ensure no issues pop up when pushing to production.

1 comment:

Alex Mohr said...

I've moved the MotoService code to a gist: https://gist.github.com/thehesiod/2e4094a1db1190f7e122e7043f1973a0