Coverage for app/routers/search.py: 89%
122 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-19 12:47 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-19 12:47 +0000
1from fastapi import APIRouter
2from fastapi import Depends, Path, Query
3from fastapi.security import OAuth2PasswordBearer
4from neo4j import GraphDatabase, basic_auth
5from app.config.config import NEO4J_HOST, NEO4J_PORT, NEO4J_USER, NEO4J_PASSWORD, RESERVED_NODE_ATTRIBUTE_NAMES, RESERVED_NODE_LABEL_NAMES, RESERVED_LINK_ATTRIBUTE_NAMES
6from rapidfuzz import fuzz
8from app.utils.check_token import check_constellation_access, constellation_check_error
9from app.utils.response_format import generate_response
10from app.utils.decode_ydoc import decode_ydoc
11from app.utils.typing import JSONValue, cast, Optional, List, serialize_value
router = APIRouter()

# OAuth2 scheme used to retrieve the JWT bearer token from incoming requests
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Neo4j driver shared by every endpoint declared in this module
driver = GraphDatabase.driver(
    f"neo4j://{NEO4J_HOST}:{NEO4J_PORT}",
    auth=basic_auth(NEO4J_USER, NEO4J_PASSWORD),
)

# Attribute names that the search endpoints leave out of fuzzy scoring
IGNORE_NODE_ATTRIBUTES = RESERVED_NODE_ATTRIBUTE_NAMES
IGNORE_LINK_ATTRIBUTES = RESERVED_LINK_ATTRIBUTE_NAMES
def score_global_key(item: JSONValue) -> float:
    """Sort key returning an item's ``score["global"]`` value as a float.

    Yields ``0.0`` whenever ``item`` is not a dict, its ``"score"`` entry is
    not a dict, or the ``"global"`` entry is missing or not an ``int``/``float``.
    """
    score_block = item.get("score") if isinstance(item, dict) else None
    global_score = score_block.get("global") if isinstance(score_block, dict) else None
    return float(global_score) if isinstance(global_score, (int, float)) else 0.0
@router.get(
    "/constellation/{constellation_uuid}/search/nodes",
    summary="Search for data in a constellation",
    description="Search for specific data within the constellation.",
    response_description="All matching nodes in the constellation",
    responses={
        200: {
            "description": "All matching nodes successfully returned",
            "content": {
                "application/json": {
                    "example": {
                        "success": True,
                        "data": [
                            {
                                "attributes": {
                                    "node_uuid": "0000-000000...",
                                    "title": "example_title",
                                    "content": "example_content",
                                },
                                "labels": [
                                    "Node",
                                ],
                                "score": {
                                    "attributes": {
                                        "title": 95,
                                        "content": 85,
                                    },
                                    "labels": {
                                        "Node": 100,
                                    },
                                },
                            },
                        ],
                        "message": "All matching nodes successfully returned",
                        "error": None,
                    },
                },
            },
        },
        **constellation_check_error,
    },
    tags=["Search"],
)
async def search_node_constellation(
    constellation_uuid: str = Path(..., description="The UUID of the constellation to get the nodes from"),
    search_query: str = Query(..., description="The search query to filter nodes by"),
    limit: int = Query(default=100, ge=0, description="The maximum number of nodes to return"),
    page: int = Query(default=1, ge=1, description="The page number to return"),
    in_filter: Optional[List[str]] = Query(default=None, description="Filter to include only specific attributes in the final results"),
    out_filter: Optional[List[str]] = Query(default=None, description="Filter to exclude specific attributes from the final results"),
    token: str = Depends(oauth2_scheme),
):
    """Fuzzy-search the nodes of a constellation.

    Every node of the constellation that carries none of the reserved labels
    is loaded, its string attributes are passed through ``decode_ydoc``, and
    each non-reserved attribute value and each label is scored against the
    lowercased query with ``fuzz.partial_ratio``. A node is kept when at least
    one of those scores is strictly above 80. Kept nodes are ordered by their
    mean score (descending) and then paginated.

    Args:
        constellation_uuid: UUID of the constellation to search.
        search_query: Text to look for (compared case-insensitively).
        limit: Page size; ``0`` disables pagination and returns every match.
        page: 1-based page number.
        in_filter: When non-empty, only these attribute keys are kept (and
            therefore scored).
        out_filter: Attribute keys to drop before scoring.
        token: Bearer token checked for ``READ`` access on the constellation.

    Returns:
        The value returned by ``check_constellation_access`` when it is not
        ``None``; otherwise the result of ``generate_response`` wrapping the
        matching nodes, each with a per-attribute / per-label ``score``.
    """
    access_error = check_constellation_access(token, constellation_uuid, "READ")
    if access_error is not None:
        return access_error

    # Sets give O(1) membership tests; an empty set means "no restriction".
    included = set(in_filter or [])
    excluded = set(out_filter or [])

    # Lowercase the search query for case-insensitive comparison
    search_query = search_query.lower()

    # NOTE: the module-level driver is shared by every request, so it must not
    # be closed here; only the session is released when the ``with`` exits.
    with driver.session() as session:
        result = session.run(
            """
                MATCH (n {constellation_uuid: $constellation_uuid})
                WHERE ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(n))
                RETURN properties(n) AS attributes, labels(n) as labels
            """,
            constellation_uuid=constellation_uuid,
            RESERVED_NODE_LABEL_NAMES=RESERVED_NODE_LABEL_NAMES,
        )
        candidates = [
            {
                "attributes": {
                    key: serialize_value(value)
                    for key, value in record["attributes"].items()
                    if (not included or key in included) and (not excluded or key not in excluded)
                },
                "labels": record["labels"],
            }
            for record in result
        ]

    # Build the list of matches in one pass instead of removing rejected
    # entries from the candidate list (list.remove is O(n) and equality-based).
    matches: list = []
    for node in candidates:
        attributes = node["attributes"]
        labels = node["labels"]
        if not isinstance(attributes, dict) or not isinstance(labels, list):
            continue

        # The decoded values are what the client receives in the response.
        decoded = {
            key: decode_ydoc(value, True) if isinstance(value, str) else value
            for key, value in attributes.items()
        }
        node["attributes"] = decoded

        # Per attribute score (reserved attributes are not scored)
        attributes_scores: dict[str, float] = {
            key: fuzz.partial_ratio(search_query, str(value).lower())
            for key, value in decoded.items()
            if key not in IGNORE_NODE_ATTRIBUTES
        }
        # Per label score
        labels_scores: dict[str, float] = {
            label: fuzz.partial_ratio(search_query, label.lower())
            for label in labels
            if isinstance(label, str)
        }

        # Skip the node unless at least one attribute value or label matches
        scored_count = len(attributes_scores) + len(labels_scores)
        if not any(value > 80 for value in attributes_scores.values()) and \
                not any(value > 80 for value in labels_scores.values()):
            continue

        node["score"] = {
            "attributes": attributes_scores,
            "labels": labels_scores,
            # Mean score for the node; scored_count > 0 because a score matched
            "global": (sum(attributes_scores.values()) + sum(labels_scores.values())) / scored_count,
        }
        matches.append(node)

    # Sort nodes by their global score in descending order
    matches.sort(key=score_global_key, reverse=True)

    # limit == 0 means "no pagination"; slicing past the end yields [].
    if limit > 0:
        skip = (page - 1) * limit
        matches = matches[skip:skip + limit]

    # The global score only drives ordering and is not part of the response
    for node in matches:
        node["score"].pop("global", None)

    return generate_response(
        status_code=200,
        data=matches,
        message="All matching nodes successfully returned"
    )
@router.get(
    "/constellation/{constellation_uuid}/search/links",
    summary="Search for data in a constellation",
    description="Search for specific data within the constellation.",
    response_description="All matching links in the constellation",
    responses={
        200: {
            "description": "All matching links successfully returned",
            "content": {
                "application/json": {
                    "example": {
                        "success": True,
                        "data": [
                            {
                                "start_node": "0000-000000...",
                                "end_node": "0000-000000...",
                                # Key aligned with the payload built below,
                                # which emits "type" (not "link_type").
                                "type": "string",
                                "attributes": {
                                    "link_uuid": "0000-000000...",
                                    "title": "example_link_title",
                                    "content": "example_link_content",
                                },
                                "score": {
                                    "attributes": {
                                        "title": 90,
                                        "content": 80,
                                    },
                                    "type": {
                                        "string": 100,
                                    },
                                },
                            },
                        ],
                        "message": "All matching links successfully returned",
                        "error": None,
                    },
                },
            },
        },
        **constellation_check_error,
    },
    tags=["Search"],
)
async def search_link_constellation(
    constellation_uuid: str = Path(..., description="The UUID of the constellation to get the links from"),
    search_query: str = Query(..., description="The search query to filter links by"),
    limit: int = Query(default=100, ge=0, description="The maximum number of links to return"),
    page: int = Query(default=1, ge=1, description="The page number to return"),
    in_filter: Optional[List[str]] = Query(default=None, description="Filter to include only specific attributes in the final results"),
    out_filter: Optional[List[str]] = Query(default=None, description="Filter to exclude specific attributes from the final results"),
    token: str = Depends(oauth2_scheme),
):
    """Fuzzy-search the links (relationships) of a constellation.

    Every relationship of the constellation whose two endpoints carry none of
    the reserved node labels is loaded, its string attributes are passed
    through ``decode_ydoc``, and each non-reserved attribute value plus the
    relationship type is scored against the lowercased query with
    ``fuzz.partial_ratio``. A link is kept when at least one of those scores
    is strictly above 80. Kept links are ordered by their mean score
    (descending) and then paginated.

    Args:
        constellation_uuid: UUID of the constellation to search.
        search_query: Text to look for (compared case-insensitively).
        limit: Page size; ``0`` disables pagination and returns every match.
        page: 1-based page number.
        in_filter: When non-empty, only these attribute keys are kept (and
            therefore scored).
        out_filter: Attribute keys to drop before scoring.
        token: Bearer token checked for ``READ`` access on the constellation.

    Returns:
        The value returned by ``check_constellation_access`` when it is not
        ``None``; otherwise the result of ``generate_response`` wrapping the
        matching links, each with a per-attribute / per-type ``score``.
    """
    access_error = check_constellation_access(token, constellation_uuid, "READ")
    if access_error is not None:
        return access_error

    # Sets give O(1) membership tests; an empty set means "no restriction".
    included = set(in_filter or [])
    excluded = set(out_filter or [])

    # Lowercase the search query for case-insensitive comparison
    search_query = search_query.lower()

    # NOTE: the module-level driver is shared by every request, so it must not
    # be closed here; only the session is released when the ``with`` exits.
    with driver.session() as session:
        result = session.run(
            """
                MATCH (n {constellation_uuid: $constellation_uuid})-[r {constellation_uuid: $constellation_uuid}]->(m {constellation_uuid: $constellation_uuid})
                WHERE ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(n)) AND ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(m))
                RETURN n,m,r, properties(r) AS attributes
            """,
            constellation_uuid=constellation_uuid,
            RESERVED_NODE_LABEL_NAMES=RESERVED_NODE_LABEL_NAMES,
        )
        candidates = [
            {
                # The pattern is (n)-[r]->(m): n is where r starts, m where it ends.
                "start_node": record["n"]["node_uuid"],
                "end_node": record["m"]["node_uuid"],
                "type": record["r"].type,
                "attributes": {
                    key: serialize_value(value)
                    for key, value in record["attributes"].items()
                    if (not included or key in included) and (not excluded or key not in excluded)
                },
            }
            for record in result
        ]

    # Build the list of matches in one pass instead of removing rejected
    # entries from the candidate list (list.remove is O(n) and equality-based).
    matches: list = []
    for link in candidates:
        attributes = link["attributes"]
        link_type = link["type"]
        if not isinstance(attributes, dict) or not isinstance(link_type, str):
            continue

        # Keep the decoded values in their original case: they are returned to
        # the client. Lowercasing happens only on the copy used for scoring.
        decoded = {
            key: decode_ydoc(value, True) if isinstance(value, str) else value
            for key, value in attributes.items()
        }
        link["attributes"] = decoded

        # Per attribute score (reserved attributes are not scored)
        attributes_scores: dict[str, float] = {
            key: fuzz.partial_ratio(search_query, str(value).lower())
            for key, value in decoded.items()
            if key not in IGNORE_LINK_ATTRIBUTES
        }
        type_score: float = fuzz.partial_ratio(search_query, link_type.lower())

        # Skip the link unless at least one attribute value or its type matches
        if not any(value > 80 for value in attributes_scores.values()) and \
                not type_score > 80:
            continue

        link["score"] = {
            "attributes": attributes_scores,
            "type": {link_type: type_score},
            # Mean score for the link (the type always contributes one score)
            "global": (sum(attributes_scores.values()) + type_score) / (len(attributes_scores) + 1),
        }
        matches.append(link)

    # Sort links by their global score in descending order
    matches.sort(key=score_global_key, reverse=True)

    # limit == 0 means "no pagination"; slicing past the end yields [].
    if limit > 0:
        skip = (page - 1) * limit
        matches = matches[skip:skip + limit]

    # The global score only drives ordering and is not part of the response
    for link in matches:
        link["score"].pop("global", None)

    return generate_response(
        status_code=200,
        data=matches,
        message="All matching links successfully returned"
    )