Coverage for app/utils/decode_ydoc.py: 54%
128 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-19 12:47 +0000
1import base64
2from pycrdt import Doc, XmlFragment, XmlElement, XmlText
3from pycrdt._pycrdt import XmlText as IXmlText, XmlElement as IXmlElement, XmlFragment as IXmlFragment
4import enum
5from loguru import logger
6from app.utils.typing import JSONValue, Optional
class MultiValueEnum(enum.Enum):
    """Enum whose members accept several equivalent spellings.

    Each member is declared with a tuple of strings. The first string becomes
    the member's canonical ``.value``; every string in the tuple (the first
    included) is accepted by ``Cls(spelling)`` lookups.
    """

    # Every spelling accepted for this member, including the canonical value.
    _aliases_: set[str]

    def __new__(cls, *values: str):
        # Called once per member at class-creation time with the tuple's items.
        obj = object.__new__(cls)
        # The first spelling is the canonical ``.value``...
        obj._value_ = values[0]
        # ...but all of them are remembered for alias lookups.
        obj._aliases_ = set(values)
        return obj

    @classmethod
    def _missing_(cls, value: object):
        """Resolve *value* against every member's alias set.

        Enum calls this hook only after the lookup by canonical value failed.

        Raises:
            ValueError: if *value* matches no alias. This includes unhashable
                values (e.g. a list), which can never be set members.
        """
        try:
            for member in cls:
                if value in member._aliases_:
                    return member
        except TypeError:
            # ``x in set`` raises TypeError for unhashable x; report it as an
            # invalid value instead of leaking the TypeError to callers.
            pass
        raise ValueError(f"{value!r} is not a valid {cls.__name__}")
class DecodeType(MultiValueEnum):
    """Rendering mode for ``ydoc:`` payloads.

    The first spelling of each member is its ``.value``; every listed
    spelling is accepted by ``DecodeType(<spelling>)``.
    """

    # Keep the document's XML serialization.
    XML = "XML", "xml", "XML_TEXT", "xml_text"
    # Extract only the text content.
    PLAIN_TEXT = "PLAIN", "plain", "PLAIN_TEXT", "plain_text"
# Plate "type" attribute values that mark an inline mention node.
_MENTION_TYPES = ("mention", "mention_inline", "mention_input")


def _mention_label(attributes) -> str:
    """Return the display text of a mention from its attribute mapping.

    Takes the first truthy of ``value``, ``label`` and ``name``. Returns it
    if it is a non-empty string, otherwise ``""``. *attributes* only needs a
    dict-like ``get`` (a plain ``dict`` or a pycrdt attributes view).
    """
    val = attributes.get("value") or attributes.get("label") or attributes.get("name")
    return f"{val}" if isinstance(val, str) and val else ""


def get_plain_text_from_xmltext(element: XmlText) -> str:
    """Flatten an ``XmlText`` node, including embedded nodes, into plain text.

    Walks ``element.diff()`` chunk by chunk:

    * string chunks are copied verbatim;
    * embedded low-level ``XmlText``/``XmlElement``/``XmlFragment`` chunks are
      wrapped in their high-level pycrdt classes. Mentions render as their
      label; anything else is flattened recursively via
      ``get_plain_text_content``;
    * any other chunk contributes text only when its formatting ``attrs``
      mark it as a mention.

    Args:
        element: the text node to flatten.

    Returns:
        The concatenated plain text. If anything raises while walking the
        node, a warning is logged and ``str(element)`` is returned instead.
    """
    try:
        out_parts: list[str] = []
        for chunk, attrs in element.diff():
            # 1) Plain string chunk: copy verbatim.
            if isinstance(chunk, str):
                out_parts.append(chunk)
                continue

            # 2) Wrap low-level integrated types returned by diff().
            doc_obj = element.doc  # type: ignore[attr-defined]
            wrapped: Optional[XmlElement | XmlText | XmlFragment]
            if isinstance(chunk, IXmlFragment):
                wrapped = XmlFragment(_doc=doc_obj, _integrated=chunk)
            elif isinstance(chunk, IXmlElement):
                wrapped = XmlElement(_doc=doc_obj, _integrated=chunk)
            elif isinstance(chunk, IXmlText):
                wrapped = XmlText(_doc=doc_obj, _integrated=chunk)
            else:
                wrapped = None

            if wrapped is not None:
                if isinstance(wrapped, XmlElement):
                    # 2a) Embedded element: a Plate mention (by tag name or
                    # "type" attribute) renders as its label; otherwise
                    # flatten its children.
                    tag = (wrapped.tag or "").lower()
                    node_type = wrapped.attributes.get("type")
                    if "mention" in tag or node_type in _MENTION_TYPES:
                        out_parts.append(_mention_label(wrapped.attributes))
                    else:
                        out_parts.append(
                            "".join(get_plain_text_content(child) for child in wrapped.children)
                        )
                elif isinstance(wrapped, XmlText):
                    # 2b) Embedded text: the mention may be carried by its own
                    # attributes; otherwise recurse into its text.
                    if wrapped.attributes.get("type") in _MENTION_TYPES:
                        out_parts.append(_mention_label(wrapped.attributes))
                    else:
                        out_parts.append(get_plain_text_content(wrapped))
                else:
                    # 2c) Embedded fragment: children are concatenated with no
                    # separator (inline context).
                    out_parts.append(
                        "".join(get_plain_text_content(child) for child in wrapped.children)
                    )
                continue

            # 3) Other embeds: the chunk's formatting attrs may describe a mention.
            if isinstance(attrs, dict) and attrs.get("type") in _MENTION_TYPES:
                out_parts.append(_mention_label(attrs))
            # 4) Anything else contributes no text.

        return "".join(out_parts)
    except Exception as exc:
        # Best-effort: degrade to the node's string form, but leave a trace.
        logger.warning(f"Plain-text extraction from XmlText failed, falling back to str(): {exc}")
        return str(element)
107def get_plain_text_content(element: XmlElement | XmlText | XmlFragment | None) -> str:
108 if isinstance(element, XmlText):
109 return get_plain_text_from_xmltext(element)
110 elif isinstance(element, XmlElement):
111 return ''.join(get_plain_text_content(child) for child in element.children)
112 elif isinstance(element, XmlFragment):
113 return '\n'.join(get_plain_text_content(child) for child in element.children)
114 return ""
def decode_ydoc(base64_update: str, decode_as_plain_text: bool = False, text_field: str = 'content') -> str:
    """Decode a ``ydoc:``-prefixed, base64-encoded Yjs update into text.

    Values without the ``ydoc:`` prefix are returned unchanged. Otherwise the
    payload after the prefix is base64-decoded and applied to a fresh
    ``pycrdt.Doc``. Every root is then read as an ``XmlFragment`` and rendered
    either as XML (``str(fragment)``) or as plain text.

    Args:
        base64_update: the stored value, possibly of the form ``"ydoc:<base64>"``.
        decode_as_plain_text: render roots as plain text instead of XML.
        text_field: name of the root to prefer when several roots decode.

    Returns:
        The text of root ``text_field`` if it decoded, else the text of the
        first root that decoded. If the payload cannot be decoded or no root
        yields text, the input is returned unchanged (``ydoc:`` prefix
        included), so callers that write the result back do not lose the
        marker.
    """
    prefix = 'ydoc:'
    if not base64_update.startswith(prefix):
        return base64_update  # Not a ydoc payload: pass through untouched.
    payload = base64_update[len(prefix):]

    # 1) Base64 -> bytes. binascii.Error subclasses ValueError, and b64decode
    # also raises ValueError for non-ASCII str input.
    try:
        update = base64.b64decode(payload)
    except ValueError as e:
        logger.error(f"Base64 decoding failed: {e}")
        return base64_update

    # 2) Build a document from the update.
    # NOTE: pycrdt is a Rust (pyo3) extension; a Rust panic would surface as
    # pyo3's PanicException, which derives from BaseException rather than
    # Exception -- presumably why the catch is this broad. Interpreter exit
    # signals are re-raised instead of being swallowed.
    doc: Doc[XmlFragment | XmlElement | XmlText] = Doc()
    try:
        doc.apply_update(update)
    except (KeyboardInterrupt, SystemExit):
        raise
    except BaseException as e:
        logger.error(f"apply_update failed: {e}")
        return base64_update

    # 3) Render every root as an XmlFragment; roots that fail are logged and
    # skipped. The key list is snapshotted before any doc.get() call.
    found_texts: dict[str, str] = {}
    for key in list(doc.keys()):
        try:
            crdt = doc.get(key, type=XmlFragment)
            if not isinstance(crdt, XmlFragment):
                raise TypeError(f"Expected XmlFragment, got {type(crdt).__name__}")
            found_texts[key] = get_plain_text_content(crdt) if decode_as_plain_text else str(crdt)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as e:
            logger.error(f"doc.get('{key}') raised: {e}")

    # 4) Prefer the requested root, then fall back to the first one decoded.
    if text_field in found_texts:
        return found_texts[text_field]
    if found_texts:
        return next(iter(found_texts.values()))

    logger.warning("No Text root found, returning original content")
    return base64_update
def decode_attributes_inplace(final_result: JSONValue, decode: Optional[DecodeType] = None) -> JSONValue:
    """Decode ``ydoc:`` string values inside each record's ``attributes``, in place.

    Args:
        final_result: expected to be a list of record dicts. Any other value,
            as well as records or ``attributes`` that are not dicts, is left
            untouched.
        decode: rendering mode. A falsy value disables decoding;
            ``DecodeType.PLAIN_TEXT`` yields plain text, anything else XML.

    Returns:
        The same ``final_result`` object, with string attribute values
        replaced by ``decode_ydoc``'s output.
    """
    if not decode or not isinstance(final_result, list):
        return final_result

    as_plain_text = decode == DecodeType.PLAIN_TEXT

    dict_records = (rec for rec in final_result if isinstance(rec, dict))
    for rec in dict_records:
        attrs = rec.get("attributes")
        if not isinstance(attrs, dict):
            continue
        # Only existing keys are reassigned, so mutating during items() is safe.
        for attr_name, raw in attrs.items():
            attrs[attr_name] = decode_ydoc(raw, as_plain_text) if isinstance(raw, str) else raw

    return final_result