share.py 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # coding: utf-8
  2. import pickle
  3. import typing
  4. import collections
  5. import redis
  6. from synergine2.exceptions import SynergineException
  7. from synergine2.exceptions import UnknownSharedData
  8. class SharedDataIndex(object):
  9. def __init__(
  10. self,
  11. shared_data_manager: 'SharedDataManager',
  12. key: str,
  13. ) -> None:
  14. self.shared_data_manager = shared_data_manager
  15. self.key = key
  16. def add(self, value: typing.Any) -> None:
  17. raise NotImplementedError()
  18. def remove(self, value: typing.Any) -> None:
  19. raise NotImplementedError()
  20. class TrackedDict(dict):
  21. base = dict
  22. def __init__(self, seq=None, **kwargs):
  23. self.key = kwargs.pop('key')
  24. self.original_key = kwargs.pop('original_key')
  25. self.shared = kwargs.pop('shared')
  26. super().__init__(seq, **kwargs)
  27. def __setitem__(self, key, value):
  28. super().__setitem__(key, value)
  29. self.shared.set(self.key, dict(self), original_key=self.original_key)
  30. def setdefault(self, k, d=None):
  31. v = super().setdefault(k, d)
  32. self.shared.set(self.key, dict(self), original_key=self.original_key)
  33. return v
  34. # TODO: Cover all methods
  35. class TrackedList(list):
  36. base = list
  37. def __init__(self, seq=(), **kwargs):
  38. self.key = kwargs.pop('key')
  39. self.original_key = kwargs.pop('original_key')
  40. self.shared = kwargs.pop('shared')
  41. super().__init__(seq)
  42. def append(self, p_object):
  43. super().append(p_object)
  44. self.shared.set(self.key, list(self), original_key=self.original_key)
  45. # TODO: Cover all methods
  46. class SharedDataManager(object):
  47. """
  48. This object is designed to own shared memory between processes. It must be feed (with set method) before
  49. start of processes. Processes will only be able to access shared memory filled here before start.
  50. """
  51. def __init__(self, clear: bool=True):
  52. self._r = redis.StrictRedis(host='localhost', port=6379, db=0) # TODO: configs
  53. self._data = {}
  54. self._modified_keys = set()
  55. self._default_values = {}
  56. self._special_types = {} # type: typing.Dict[str, typing.Union[typing.Type[TrackedDict], typing.Type[TrackedList]]] # nopep8
  57. if clear:
  58. self.clear()
  59. def clear(self) -> None:
  60. self._r.flushdb()
  61. self._data = {}
  62. self._modified_keys = set()
  63. def reset(self) -> None:
  64. for key, value in self._default_values.items():
  65. self.set(key, value)
  66. self.commit()
  67. self._data = {}
  68. def set(self, key: str, value: typing.Any, original_key: str=None) -> None:
  69. try:
  70. special_type, original_key_ = self._special_types[key]
  71. value = special_type(value, key=key, shared=self, original_key=original_key)
  72. except KeyError:
  73. try:
  74. # TODO: Code degeu pour gerer les {id}_truc
  75. special_type, original_key_ = self._special_types[original_key]
  76. value = special_type(value, key=key, shared=self, original_key=original_key)
  77. except KeyError:
  78. pass
  79. self._data[key] = value
  80. self._modified_keys.add((key, original_key))
  81. def get(self, *key_args: typing.Union[str, float, int]) -> typing.Any:
  82. key = '_'.join([str(v) for v in key_args])
  83. try:
  84. return self._data[key]
  85. except KeyError:
  86. b_value = self._r.get(key)
  87. if b_value is None:
  88. # We not allow None value storage
  89. raise UnknownSharedData('No shared data for key "{}"'.format(key))
  90. value = pickle.loads(b_value)
  91. special_type = None
  92. try:
  93. special_type, original_key = self._special_types[key]
  94. except KeyError:
  95. pass
  96. if special_type:
  97. self._data[key] = special_type(value, key=key, shared=self, original_key=original_key)
  98. else:
  99. self._data[key] = value
  100. return self._data[key]
  101. def commit(self) -> None:
  102. for key, original_key in self._modified_keys:
  103. try:
  104. special_type, original_key = self._special_types[key]
  105. value = special_type.base(self.get(key))
  106. self._r.set(key, pickle.dumps(value))
  107. except KeyError:
  108. # Code degeu pour gerer les {id}_truc
  109. try:
  110. special_type, original_key = self._special_types[original_key]
  111. value = special_type.base(self.get(key))
  112. self._r.set(key, pickle.dumps(value))
  113. except KeyError:
  114. self._r.set(key, pickle.dumps(self.get(key)))
  115. self._modified_keys = set()
  116. def refresh(self) -> None:
  117. self._data = {}
  118. def make_index(
  119. self,
  120. shared_data_index_class: typing.Type[SharedDataIndex],
  121. key: str,
  122. *args: typing.Any,
  123. **kwargs: typing.Any
  124. ) -> SharedDataIndex:
  125. return shared_data_index_class(self, key, *args, **kwargs)
  126. def create(
  127. self,
  128. key_args: typing.Union[str, typing.List[typing.Union[str, int, float]]],
  129. value: typing.Any,
  130. indexes: typing.List[SharedDataIndex]=None,
  131. ):
  132. key = key_args
  133. if not isinstance(key, str):
  134. key = '_'.join(key_args)
  135. indexes = indexes or []
  136. if type(value) is dict:
  137. value = TrackedDict(value, key=key, shared=shared, original_key=key)
  138. self._special_types[key] = TrackedDict, key
  139. elif type(value) is list:
  140. value = TrackedList(value, key=key, shared=shared, original_key=key)
  141. self._special_types[key] = TrackedList, key
  142. def get_key(obj):
  143. return key
  144. def get_key_with_id(obj):
  145. return key.format(id=obj.id)
  146. if '{id}' in key:
  147. key_formatter = get_key_with_id
  148. else:
  149. self.set(key, value)
  150. self._default_values[key] = value
  151. key_formatter = get_key
  152. def fget(self_):
  153. return self.get(key_formatter(self_))
  154. def fset(self_, value_):
  155. try:
  156. previous_value = self.get(key_formatter(self_))
  157. for index in indexes:
  158. index.remove(previous_value)
  159. except UnknownSharedData:
  160. pass # If no shared data, no previous value to remove
  161. self.set(key_formatter(self_), value_, original_key=key)
  162. for index in indexes:
  163. index.add(value_)
  164. def fdel(self_):
  165. raise SynergineException('You cannot delete a shared data')
  166. shared_property = property(
  167. fget=fget,
  168. fset=fset,
  169. fdel=fdel,
  170. )
  171. return shared_property
  172. # TODO: Does exist a way to permit overload of SharedDataManager class ?
  173. shared = SharedDataManager()
  174. class ListIndex(SharedDataIndex):
  175. def add(self, value):
  176. try:
  177. values = self.shared_data_manager.get(self.key)
  178. except UnknownSharedData:
  179. values = []
  180. values.append(value)
  181. self.shared_data_manager.set(self.key, values)
  182. def remove(self, value):
  183. values = self.shared_data_manager.get(self.key)
  184. values.remove(value)
  185. self.shared_data_manager.set(self.key, values)