user.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # -*- coding: utf-8 -*-
  2. import threading
  3. import cherrypy
  4. import transaction
  5. import tg
  6. import typing as typing
  7. from tracim.model.auth import User
  8. from tracim.model import DBSession
  9. CURRENT_USER_WEB = 'WEB'
  10. CURRENT_USER_WSGIDAV = 'WSGIDAV'
  11. class UserApi(object):
  12. def __init__(self, current_user: User):
  13. self._user = current_user
  14. def get_all(self):
  15. return DBSession.query(User).order_by(User.display_name).all()
  16. def _base_query(self):
  17. return DBSession.query(User)
  18. def get_one(self, user_id: int):
  19. return self._base_query().filter(User.user_id==user_id).one()
  20. def get_one_by_email(self, email: str):
  21. return self._base_query().filter(User.email==email).one()
  22. def get_one_by_id(self, id: int) -> User:
  23. return self._base_query().filter(User.user_id==id).one()
  24. def update(
  25. self,
  26. user: User,
  27. name: str=None,
  28. email: str=None,
  29. do_save=True,
  30. timezone: str='',
  31. ):
  32. if name is not None:
  33. user.display_name = name
  34. if email is not None:
  35. user.email = email
  36. user.timezone = timezone
  37. if do_save:
  38. self.save(user)
  39. if email and self._user and user.user_id==self._user.user_id:
  40. # this is required for the session to keep on being up-to-date
  41. tg.request.identity['repoze.who.userid'] = email
  42. tg.auth_force_login(email)
  43. def user_with_email_exists(self, email: str):
  44. try:
  45. self.get_one_by_email(email)
  46. return True
  47. except:
  48. return False
  49. def create_user(self, email=None, groups=[], save_now=False) -> User:
  50. user = User()
  51. if email:
  52. user.email = email
  53. for group in groups:
  54. user.groups.append(group)
  55. DBSession.add(user)
  56. if save_now:
  57. DBSession.flush()
  58. return user
  59. def save(self, user: User):
  60. DBSession.flush()
  61. def execute_created_user_actions(self, created_user: User) -> None:
  62. """
  63. Execute actions when user just been created
  64. :return:
  65. """
  66. # NOTE: Cyclic import
  67. from tracim.lib.calendar import CalendarManager
  68. from tracim.model.organisational import UserCalendar
  69. created_user.ensure_auth_token()
  70. # Ensure database is up-to-date
  71. DBSession.flush()
  72. transaction.commit()
  73. calendar_manager = CalendarManager(created_user)
  74. calendar_manager.create_then_remove_fake_event(
  75. calendar_class=UserCalendar,
  76. related_object_id=created_user.user_id,
  77. )
  78. class CurrentUserGetterInterface(object):
  79. def get_current_user(self) -> typing.Union[None, User]:
  80. raise NotImplementedError()
  81. class BaseCurrentUserGetter(CurrentUserGetterInterface):
  82. def __init__(self) -> None:
  83. self.api = UserApi(None)
  84. class WebCurrentUserGetter(BaseCurrentUserGetter):
  85. def get_current_user(self) -> typing.Union[None, User]:
  86. # HACK - D.A. - 2015-09-02
  87. # In tests, the tg.request.identity may not be set
  88. # (this is a buggy case, but for now this is how the software is;)
  89. if tg.request is not None:
  90. if hasattr(tg.request, 'identity'):
  91. if tg.request.identity is not None:
  92. return self.api.get_one_by_email(
  93. tg.request.identity['repoze.who.userid'],
  94. )
  95. return None
  96. class WsgidavCurrentUserGetter(BaseCurrentUserGetter):
  97. def get_current_user(self) -> typing.Union[None, User]:
  98. if hasattr(cherrypy.request, 'current_user_email'):
  99. return self.api.get_one_by_email(
  100. cherrypy.request.current_user_email,
  101. )
  102. return None
  103. class CurrentUserGetterApi(object):
  104. thread_local = threading.local()
  105. matches = {
  106. CURRENT_USER_WEB: WebCurrentUserGetter,
  107. CURRENT_USER_WSGIDAV: WsgidavCurrentUserGetter,
  108. }
  109. default = CURRENT_USER_WEB
  110. @classmethod
  111. def get_current_user(cls) -> User:
  112. try:
  113. return cls.thread_local.getter.get_current_user()
  114. except AttributeError:
  115. return cls.factory(cls.default).get_current_user()
  116. @classmethod
  117. def set_thread_local_getter(cls, name) -> None:
  118. if not hasattr(cls.thread_local, 'getter'):
  119. cls.thread_local.getter = cls.factory(name)
  120. @classmethod
  121. def factory(cls, name: str) -> CurrentUserGetterInterface:
  122. return cls.matches[name]()