Browse Source

Support handle exception for aiohttp

Bastien Sevajol 5 years ago
parent
commit
c8ff050448
3 changed files with 95 additions and 22 deletions
  1. 57 16
      hapic/decorator.py
  2. 16 6
      hapic/hapic.py
  3. 22 0
      tests/ext/unit/test_aiohttp.py

+ 57 - 16
hapic/decorator.py View File

@@ -564,24 +564,65 @@ class ExceptionHandlerControllerWrapper(ControllerWrapper):
564 564
                 func_kwargs,
565 565
             )
566 566
         except self.handled_exception_class as exc:
567
-            response_content = self.error_builder.build_from_exception(
568
-                exc,
569
-                include_traceback=self.context.is_debug(),
570
-            )
567
+            return self._build_error_response(exc)
568
+
569
+    def _build_error_response(self, exc: Exception) -> typing.Any:
570
+        response_content = self.error_builder.build_from_exception(
571
+            exc,
572
+            include_traceback=self.context.is_debug(),
573
+        )
571 574
 
572
-            # Check error format
573
-            dumped = self.error_builder.dump(response_content).data
574
-            unmarshall = self.error_builder.load(dumped)
575
-            if unmarshall.errors:
576
-                raise OutputValidationException(
577
-                    'Validation error during dump of error response: {}'
575
+        # Check error format
576
+        dumped = self.error_builder.dump(response_content).data
577
+        unmarshall = self.error_builder.load(dumped)
578
+        if unmarshall.errors:
579
+            raise OutputValidationException(
580
+                'Validation error during dump of error response: {}'
578 581
                     .format(
579
-                        str(unmarshall.errors)
580
-                    )
582
+                    str(unmarshall.errors)
581 583
                 )
584
+            )
585
+
586
+        error_response = self.context.get_response(
587
+            json.dumps(dumped),
588
+            self.http_code,
589
+        )
590
+        return error_response
591
+
582 592
 
583
-            error_response = self.context.get_response(
584
-                json.dumps(dumped),
585
-                self.http_code,
593
+# TODO BS 2018-07-23: This class is an async version of
594
+# ExceptionHandlerControllerWrapper
595
+# to permit async compatibility. Please re-think about code refact
596
+# TAG: REFACT_ASYNC
597
+class AsyncExceptionHandlerControllerWrapper(ExceptionHandlerControllerWrapper):
598
+    def get_wrapper(
599
+        self,
600
+        func: 'typing.Callable[..., typing.Any]',
601
+    ) -> 'typing.Callable[..., typing.Any]':
602
+        # async def wrapper(*args, **kwargs) -> typing.Any:
603
+        async def wrapper(*args, **kwargs) -> typing.Any:
604
+            # Note: Design of before_wrapped_func can be to update kwargs
605
+            # by reference here
606
+            replacement_response = self.before_wrapped_func(args, kwargs)
607
+            if replacement_response is not None:
608
+                return replacement_response
609
+
610
+            response = await self._execute_wrapped_function(func, args, kwargs)
611
+            new_response = self.after_wrapped_function(response)
612
+            return new_response
613
+        return functools.update_wrapper(wrapper, func)
614
+
615
+    async def _execute_wrapped_function(
616
+        self,
617
+        func,
618
+        func_args,
619
+        func_kwargs,
620
+    ) -> typing.Any:
621
+        try:
622
+            return await super()._execute_wrapped_function(
623
+                func,
624
+                func_args,
625
+                func_kwargs,
586 626
             )
587
-            return error_response
627
+        except self.handled_exception_class as exc:
628
+            return self._build_error_response(exc)

+ 16 - 6
hapic/hapic.py View File

@@ -14,6 +14,7 @@ from hapic.decorator import DecoratedController
14 14
 from hapic.decorator import DECORATION_ATTRIBUTE_NAME
15 15
 from hapic.decorator import ControllerReference
16 16
 from hapic.decorator import ExceptionHandlerControllerWrapper
17
+from hapic.decorator import AsyncExceptionHandlerControllerWrapper
17 18
 from hapic.decorator import InputBodyControllerWrapper
18 19
 from hapic.decorator import AsyncInputBodyControllerWrapper
19 20
 from hapic.decorator import InputHeadersControllerWrapper
@@ -405,12 +406,21 @@ class Hapic(object):
405 406
         context = context or self._context_getter
406 407
         error_builder = error_builder or self._error_builder_getter
407 408
 
408
-        decoration = ExceptionHandlerControllerWrapper(
409
-            handled_exception_class,
410
-            context,
411
-            error_builder=error_builder,
412
-            http_code=http_code,
413
-        )
409
+        if self._async:
410
+            decoration = AsyncExceptionHandlerControllerWrapper(
411
+                handled_exception_class,
412
+                context,
413
+                error_builder=error_builder,
414
+                http_code=http_code,
415
+            )
416
+
417
+        else:
418
+            decoration = ExceptionHandlerControllerWrapper(
419
+                handled_exception_class,
420
+                context,
421
+                error_builder=error_builder,
422
+                http_code=http_code,
423
+            )
414 424
 
415 425
         def decorator(func):
416 426
             self._buffer.errors.append(ErrorDescription(decoration))

+ 22 - 0
tests/ext/unit/test_aiohttp.py View File

@@ -187,6 +187,28 @@ class TestAiohttpExt(object):
187 187
                    'i': ['Missing data for required field.'],
188 188
                } == data.get('details')
189 189
 
190
+    async def test_aiohttp_handle_excpetion__ok__nominal_case(
191
+        self,
192
+        aiohttp_client,
193
+        loop,
194
+    ):
195
+        hapic = Hapic(async_=True)
196
+
197
+        @hapic.handle_exception(ZeroDivisionError, http_code=400)
198
+        async def hello(request):
199
+            1 / 0
200
+
201
+        app = web.Application(debug=True)
202
+        app.router.add_get('/', hello)
203
+        hapic.set_context(AiohttpContext(app))
204
+        client = await aiohttp_client(app)
205
+
206
+        resp = await client.get('/')
207
+        assert resp.status == 400
208
+
209
+        data = await resp.json()
210
+        assert 'division by zero' == data.get('message')
211
+
190 212
     async def test_aiohttp_output_stream__ok__nominal_case(
191 213
         self,
192 214
         aiohttp_client,