Coverage for app/utils/decode_ydoc.py: 54%

128 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-19 12:47 +0000

1import base64 

2from pycrdt import Doc, XmlFragment, XmlElement, XmlText 

3from pycrdt._pycrdt import XmlText as IXmlText, XmlElement as IXmlElement, XmlFragment as IXmlFragment 

4import enum 

5from loguru import logger 

6from app.utils.typing import JSONValue, Optional 

7 

class MultiValueEnum(enum.Enum):
    """Enum base whose members can be looked up by any of several spellings.

    Each member is declared as a tuple of strings. The first string becomes the
    member's canonical ``.value``; all strings (the first included) are stored
    in ``_aliases_`` and accepted by ``Cls(<spelling>)``. Unknown spellings
    raise ``ValueError``.
    """

    _aliases_: set[str]

    def __new__(cls, *values: str):
        # Called once per member with the declared tuple unpacked as arguments.
        member = object.__new__(cls)
        member._value_ = values[0]  # canonical value = first spelling
        member._aliases_ = {*values}  # every accepted spelling
        return member

    @classmethod
    def _missing_(cls, value: object):
        # Reached only when `value` is not a canonical `.value`; fall back to
        # scanning every member's alias set, in declaration order.
        found = next((candidate for candidate in cls if value in candidate._aliases_), None)
        if found is not None:
            return found
        raise ValueError(f"{value!r} is not a valid {cls.__name__}")

29 

class DecodeType(MultiValueEnum):
    """Requested decoding mode for ydoc-encoded attribute values.

    XML:        decode_ydoc returns ``str()`` of each XmlFragment root.
                Accepted spellings: "XML", "xml", "XML_TEXT", "xml_text".
    PLAIN_TEXT: decode_ydoc extracts text via get_plain_text_content.
                Accepted spellings: "PLAIN", "plain", "PLAIN_TEXT", "plain_text".

    The first spelling of each tuple is the member's canonical ``.value``.
    """

    XML = "XML", "xml", "XML_TEXT", "xml_text"
    PLAIN_TEXT = "PLAIN", "plain", "PLAIN_TEXT", "plain_text"

33 

# `type` attribute values that mark a node as a mention (Plate-style mentions).
_MENTION_TYPES = ("mention", "mention_inline", "mention_input")


def _mention_text(attributes) -> str:
    """Return a mention's display text from a dict-like attribute container.

    The first truthy value among ``value``, ``label`` and ``name`` is used.
    It is returned only if it is a non-empty ``str``; otherwise ``""``.
    """
    val = attributes.get("value") or attributes.get("label") or attributes.get("name")
    return val if isinstance(val, str) and val else ""


def _wrap_embed(chunk: object, owner: XmlText) -> Optional[XmlElement | XmlText | XmlFragment]:
    """Turn an embedded chunk from ``XmlText.diff()`` into a pycrdt wrapper object.

    Raw integrated types (from ``pycrdt._pycrdt``) are wrapped with the owner's
    document. Chunks that are already pycrdt wrappers are passed through.
    Returns ``None`` for anything else.
    """
    if isinstance(chunk, (XmlElement, XmlText, XmlFragment)):
        return chunk
    if isinstance(chunk, IXmlText):
        return XmlText(_doc=owner.doc, _integrated=chunk)  # type: ignore[attr-defined]
    if isinstance(chunk, IXmlElement):
        return XmlElement(_doc=owner.doc, _integrated=chunk)  # type: ignore[attr-defined]
    if isinstance(chunk, IXmlFragment):
        return XmlFragment(_doc=owner.doc, _integrated=chunk)  # type: ignore[attr-defined]
    return None


def get_plain_text_from_xmltext(element: XmlText) -> str:
    """Flatten an XmlText into plain text, rendering mentions by their value.

    Walks ``element.diff()`` chunk by chunk:
      1. String chunks are appended as-is; their formatting attrs are ignored.
      2. Embedded XmlElement / XmlText / XmlFragment chunks are handled as follows.
         - An element whose tag contains "mention", or whose ``type`` attribute
           is one of ``_MENTION_TYPES``, contributes its mention text.
         - Any other embed contributes the plain text of its content.
      3. Any other chunk whose ``attrs`` dict marks a mention contributes that
         mention's text.
      4. Everything else contributes nothing.

    Args:
        element: The XmlText to flatten.

    Returns:
        The concatenated plain text. If anything raises during extraction,
        the error is logged and ``str(element)`` is returned as a best-effort
        fallback; that form may still contain markup.
    """
    try:
        out_parts: list[str] = []
        for chunk, attrs in element.diff():
            # 1) Plain string chunk.
            if isinstance(chunk, str):
                out_parts.append(chunk)
                continue

            # 2) Embedded shared type.
            wrapped = _wrap_embed(chunk, element)

            if isinstance(wrapped, XmlElement):
                tag = (wrapped.tag or "").lower()
                node_type = wrapped.attributes.get("type")
                if "mention" in tag or node_type in _MENTION_TYPES:
                    out_parts.append(_mention_text(wrapped.attributes))
                else:
                    out_parts.append("".join(get_plain_text_content(child) for child in wrapped.children))
                continue

            if isinstance(wrapped, XmlText):
                # The mention may be encoded in the embedded XmlText's own attributes.
                if wrapped.attributes.get("type") in _MENTION_TYPES:
                    out_parts.append(_mention_text(wrapped.attributes))
                else:
                    # Its text may itself contain further embeds.
                    out_parts.append(get_plain_text_content(wrapped))
                continue

            if wrapped is not None:
                # Embedded XmlFragment: children joined without a separator.
                out_parts.append("".join(get_plain_text_content(child) for child in wrapped.children))
                continue

            # 3) Unrecognised embed: the mention may live in the chunk's attrs.
            if isinstance(attrs, dict) and attrs.get("type") in _MENTION_TYPES:
                out_parts.append(_mention_text(attrs))
            # 4) Anything else is dropped.

        return "".join(out_parts)
    except Exception as exc:
        # Best-effort fallback, but leave a trace instead of failing silently.
        logger.warning(f"Plain-text extraction from XmlText failed, falling back to str(): {exc!r}")
        return str(element)

106 

107def get_plain_text_content(element: XmlElement | XmlText | XmlFragment | None) -> str: 

108 if isinstance(element, XmlText): 

109 return get_plain_text_from_xmltext(element) 

110 elif isinstance(element, XmlElement): 

111 return ''.join(get_plain_text_content(child) for child in element.children) 

112 elif isinstance(element, XmlFragment): 

113 return '\n'.join(get_plain_text_content(child) for child in element.children) 

114 return "" 

115 

def decode_ydoc(base64_update: str, decode_as_plain_text: bool = False, text_field: str = 'content') -> str:
    """Decode a ``"ydoc:"``-prefixed, base64-encoded Yjs update into text.

    Strings without the ``"ydoc:"`` prefix are returned unchanged. Otherwise
    the payload is processed in three steps:
      1. The payload is base64-decoded.
      2. The bytes are applied to a fresh pycrdt ``Doc``.
      3. Every root is read as an ``XmlFragment``.

    Args:
        base64_update: Candidate value; only ``"ydoc:<base64>"`` strings are decoded.
        decode_as_plain_text: If True, text is extracted with
            get_plain_text_content; otherwise the fragment's ``str()`` is used.
        text_field: Root name to prefer when the document has several roots.

    Returns:
        - The text of root ``text_field`` if it could be read.
        - Otherwise, the text of the first root that could be read.
        - Otherwise, the ORIGINAL input, ``"ydoc:"`` prefix included. This also
          applies when base64 decoding or ``apply_update`` fails. Keeping the
          prefix means callers that store the result can still recognise the
          value as an undecoded ydoc.
    """
    prefix = 'ydoc:'
    original = base64_update
    if not base64_update.startswith(prefix):
        return original  # Not a ydoc payload: pass through untouched.
    payload = base64_update[len(prefix):]

    # 1) Base64 -> bytes. binascii.Error subclasses ValueError, and a str with
    #    non-ASCII characters also raises ValueError.
    try:
        update = base64.b64decode(payload)
    except ValueError as e:
        logger.error(f"Base64 decoding failed: {e}")
        return original

    # 2) Create the document and apply the update.
    doc: Doc[XmlFragment | XmlElement | XmlText] = Doc()
    try:
        doc.apply_update(update)
    except (KeyboardInterrupt, SystemExit):
        raise
    except BaseException as e:
        # Broad on purpose, presumably so that a Rust panic inside pycrdt (PyO3's
        # PanicException derives from BaseException) degrades to a fallback
        # instead of escaping. Interpreter-exit signals are re-raised above.
        logger.error(f"apply_update failed: {e}")
        return original

    # 3) Extract each root as an XmlFragment; roots that fail are logged and skipped.
    found_texts: dict[str, str] = {}
    for key in list(doc.keys()):
        try:
            crdt = doc.get(key, type=XmlFragment)
            if not isinstance(crdt, XmlFragment):
                raise TypeError(f"Expected XmlFragment, got {type(crdt).__name__}")
            found_texts[key] = get_plain_text_content(crdt) if decode_as_plain_text else str(crdt)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as e:
            logger.error(f"doc.get('{key}') raised: {e}")

    # 4) Prefer the requested root, else the first one found (insertion order).
    if text_field in found_texts:
        return found_texts[text_field]
    if found_texts:
        return next(iter(found_texts.values()))

    logger.warning("No Text root found, returning original content")
    return original

175 

def decode_attributes_inplace(final_result: JSONValue, decode: Optional[DecodeType] = None) -> JSONValue:
    """Decode ydoc-encoded string attributes of each record, mutating them in place.

    Only acts when ``decode`` is truthy and ``final_result`` is a list. For
    each dict record with a dict ``"attributes"``, every string value is
    passed through decode_ydoc. Plain-text mode is used when
    ``decode == DecodeType.PLAIN_TEXT``. Each key is written back, and
    non-string values are written back unchanged. Other records are skipped.

    Returns:
        The same ``final_result`` object it was given.
    """
    if not decode or not isinstance(final_result, list):
        return final_result

    as_plain_text = (decode == DecodeType.PLAIN_TEXT)

    dict_records = (rec for rec in final_result if isinstance(rec, dict))
    for rec in dict_records:
        attrs = rec.get("attributes")
        if not isinstance(attrs, dict):
            continue
        # Reassigning existing keys while iterating items() is safe: the dict's size never changes.
        for attr_name, raw in attrs.items():
            if isinstance(raw, str):
                attrs[attr_name] = decode_ydoc(raw, as_plain_text)
            else:
                attrs[attr_name] = raw

    return final_result