test_decorator.py 7.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # -*- coding: utf-8 -*-
  2. import typing
  3. from http import HTTPStatus
  4. import marshmallow
  5. from hapic.context import ContextInterface
  6. from hapic.data import HapicData
  7. from hapic.decorator import InputOutputControllerWrapper
  8. from hapic.decorator import ExceptionHandlerControllerWrapper
  9. from hapic.decorator import InputControllerWrapper
  10. from hapic.decorator import OutputControllerWrapper
  11. from hapic.hapic import ErrorResponseSchema
  12. from hapic.processor import RequestParameters
  13. from hapic.processor import MarshmallowOutputProcessor
  14. from hapic.processor import ProcessValidationError
  15. from hapic.processor import ProcessorInterface
  16. from tests.base import Base
  class MyContext(ContextInterface):
      """Test double for ``ContextInterface`` that returns plain dicts.

      Nothing is serialized: each method hands back a dict holding its
      inputs so tests can inspect exactly what the wrappers passed in.
      """

      def get_request_parameters(self, *args, **kwargs) -> RequestParameters:
          """Expose the positional call arguments as fake path parameters.

          The ``args`` tuple is stored under the ``'fake'`` path parameter;
          ``kwargs`` are ignored and every other parameter group is empty.
          """
          fake_path_parameters = {'fake': args}
          return RequestParameters(
              path_parameters=fake_path_parameters,
              query_parameters={},
              body_parameters={},
              form_parameters={},
              header_parameters={},
          )

      def get_response(self, response: dict, http_code: int) -> typing.Any:
          """Return ``response`` and ``http_code`` bundled in a dict."""
          return dict(original_response=response, http_code=http_code)

      def get_validation_error_response(
          self,
          error: ProcessValidationError,
          http_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
      ) -> typing.Any:
          """Return ``error`` and ``http_code`` bundled in a dict."""
          return dict(original_error=error, http_code=http_code)
  class MyProcessor(ProcessorInterface):
      """Toy processor that increments values and builds a canned error."""

      def process(self, value):
          """Return ``value`` plus one."""
          incremented = value + 1
          return incremented

      def get_validation_error(
          self,
          request_context: RequestParameters,
      ) -> ProcessValidationError:
          """Build an 'ERROR' validation error embedding ``request_context``."""
          details = {'original_request_context': request_context}
          return ProcessValidationError(
              error_details=details,
              error_message='ERROR',
          )
  class MyControllerWrapper(InputOutputControllerWrapper):
      """Wrapper overriding both hooks, used to test their effects."""

      def before_wrapped_func(
          self,
          func_args: typing.Tuple[typing.Any, ...],
          func_kwargs: typing.Dict[str, typing.Any],
      ) -> typing.Union[None, typing.Any]:
          """Short-circuit when the first positional arg is 666.

          Otherwise inject ``added_parameter='a value'`` into ``func_kwargs``
          (mutated in place) and return None. The tests in this module expect
          the dict returned for 666 to come back in place of the controller's
          own result.
          """
          if not (func_args and func_args[0] == 666):
              func_kwargs['added_parameter'] = 'a value'
              return None
          return {
              'error_response': 'we are testing',
          }

      def after_wrapped_function(self, response: typing.Any) -> typing.Any:
          """Return the controller's response multiplied by two."""
          doubled = response * 2
          return doubled
  class MyInputControllerWrapper(InputControllerWrapper):
      """Input wrapper that publishes path parameters on ``hapic_data.query``."""

      def get_processed_data(
          self,
          request_parameters: RequestParameters,
      ) -> typing.Any:
          """Wrap the request's path parameters under ``'we_are_testing'``."""
          path_params = request_parameters.path_parameters
          return {'we_are_testing': path_params}

      def update_hapic_data(
          self,
          hapic_data: HapicData,
          processed_data: typing.Dict[str, typing.Any],
      ) -> typing.Any:
          """Assign ``processed_data`` to ``hapic_data.query``; returns None."""
          hapic_data.query = processed_data
          return None
  class MySchema(marshmallow.Schema):
      """Schema with a single mandatory string field, ``name``."""

      # ``fields.Str`` is marshmallow's alias for ``fields.String``.
      name = marshmallow.fields.Str(required=True)
  class TestControllerWrapper(Base):
      """Tests for the before/after hooks of ``InputOutputControllerWrapper``."""

      def test_unit__base_controller_wrapper__ok__no_behaviour(self):
          # With no hook overridden, the wrapper must be transparent.
          wrapper = InputOutputControllerWrapper(MyContext(), MyProcessor())

          @wrapper.get_wrapper
          def func(foo):
              return foo

          assert func(42) == 42

      def test_unit__base_controller__ok__replaced_response(self):
          wrapper = MyControllerWrapper(MyContext(), MyProcessor())

          @wrapper.get_wrapper
          def func(foo):
              return foo

          # 666 triggers the short-circuit in
          # MyControllerWrapper#before_wrapped_func; its return value is
          # expected to replace the controller's response.
          result = func(666)
          assert {'error_response': 'we are testing'} == result

      def test_unit__controller_wrapper__ok__overload_input(self):
          wrapper = MyControllerWrapper(MyContext(), MyProcessor())

          @wrapper.get_wrapper
          def func(foo, added_parameter=None):
              # Injected by MyControllerWrapper#before_wrapped_func.
              assert added_parameter == 'a value'
              return foo

          # MyControllerWrapper#after_wrapped_function doubles the response.
          assert func(42) == 84
  class TestInputControllerWrapper(Base):
      """Tests for ``InputControllerWrapper`` populating ``hapic_data``."""

      def test_unit__input_data_wrapping__ok__nominal_case(self):
          context = MyContext()
          processor = MyProcessor()
          wrapper = MyInputControllerWrapper(context, processor)

          @wrapper.get_wrapper
          def func(foo, hapic_data=None):
              assert hapic_data
              assert isinstance(hapic_data, HapicData)
              # MyContext#get_request_parameters stores the positional args
              # under 'fake'; MyInputControllerWrapper#get_processed_data wraps
              # them under 'we_are_testing' and #update_hapic_data puts the
              # result on hapic_data.query.
              assert hapic_data.query == {'we_are_testing': {'fake': (42,)}}
              return foo

          result = func(42)
          assert result == 42
  class TestOutputControllerWrapper(Base):
      """Tests for ``OutputControllerWrapper`` processing controller output."""

      def test_unit__output_data_wrapping__ok__nominal_case(self):
          context = MyContext()
          processor = MyProcessor()
          wrapper = OutputControllerWrapper(context, processor)

          @wrapper.get_wrapper
          def func(foo, hapic_data=None):
              # Without an input wrapper, no hapic_data is passed in.
              assert not hapic_data
              return foo

          result = func(42)
          # MyProcessor#process increments the controller's output (42 -> 43).
          assert {
              'http_code': HTTPStatus.OK,
              'original_response': 43,
          } == result

      def test_unit__output_data_wrapping__fail__error_response(self):
          context = MyContext()
          processor = MarshmallowOutputProcessor()
          processor.schema = MySchema()
          wrapper = OutputControllerWrapper(context, processor)

          @wrapper.get_wrapper
          def func(foo):
              return 'wrong result format'

          result = func(42)
          # The output does not satisfy MySchema (required 'name' is missing),
          # so MyContext#get_validation_error_response builds the result with
          # a 500 status instead of a normal response.
          assert isinstance(result, dict)
          assert 'http_code' in result
          assert result['http_code'] == HTTPStatus.INTERNAL_SERVER_ERROR
          assert 'original_error' in result
          assert result['original_error'].error_details == {
              'name': ['Missing data for required field.']
          }
  class TestExceptionHandlerControllerWrapper(Base):
      """Tests for ``ExceptionHandlerControllerWrapper`` error responses."""

      def test_unit__exception_handled__ok__nominal_case(self):
          context = MyContext()
          wrapper = ExceptionHandlerControllerWrapper(
              ZeroDivisionError,
              context,
              schema=ErrorResponseSchema(),
              http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
          )

          @wrapper.get_wrapper
          def func(foo):
              raise ZeroDivisionError('We are testing')

          response = func(42)
          assert 'http_code' in response
          assert response['http_code'] == HTTPStatus.INTERNAL_SERVER_ERROR
          assert 'original_response' in response
          # A plain exception yields its message with an empty detail.
          assert response['original_response'] == {
              'message': 'We are testing',
              'code': None,
              'detail': {},
          }

      def test_unit__exception_handled__ok__exception_error_dict(self):
          class MyException(Exception):
              def __init__(self, *args, **kwargs):
                  super().__init__(*args, **kwargs)
                  # Initialise the same attribute the test fills in below and
                  # expects to see as 'detail' (was a stray 'error_dict').
                  self.error_detail = {}

          context = MyContext()
          wrapper = ExceptionHandlerControllerWrapper(
              MyException,
              context,
              schema=ErrorResponseSchema(),
              http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
          )

          @wrapper.get_wrapper
          def func(foo):
              exc = MyException('We are testing')
              exc.error_detail = {'foo': 'bar'}
              raise exc

          response = func(42)
          assert 'http_code' in response
          assert response['http_code'] == HTTPStatus.INTERNAL_SERVER_ERROR
          assert 'original_response' in response
          # The exception's error_detail is expected to surface as 'detail'.
          assert response['original_response'] == {
              'message': 'We are testing',
              'code': None,
              'detail': {'foo': 'bar'},
          }