Coverage for app/routers/search.py: 89%

122 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-19 12:47 +0000

1from fastapi import APIRouter 

2from fastapi import Depends, Path, Query 

3from fastapi.security import OAuth2PasswordBearer 

4from neo4j import GraphDatabase, basic_auth 

5from app.config.config import NEO4J_HOST, NEO4J_PORT, NEO4J_USER, NEO4J_PASSWORD, RESERVED_NODE_ATTRIBUTE_NAMES, RESERVED_NODE_LABEL_NAMES, RESERVED_LINK_ATTRIBUTE_NAMES 

6from rapidfuzz import fuzz 

7 

8from app.utils.check_token import check_constellation_access, constellation_check_error 

9from app.utils.response_format import generate_response 

10from app.utils.decode_ydoc import decode_ydoc 

11from app.utils.typing import JSONValue, cast, Optional, List, serialize_value 

12 

# Router on which the search endpoints of this module are registered
router = APIRouter()

# Define the OAuth2 scheme used to retrieve the JWT
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Define the Neo4j driver (module-level, used by every handler below)
driver = GraphDatabase.driver(
    f"neo4j://{NEO4J_HOST}:{NEO4J_PORT}",
    auth=basic_auth(NEO4J_USER, NEO4J_PASSWORD),
)

# Reserved attribute names that are left out of the fuzzy scoring
IGNORE_NODE_ATTRIBUTES = RESERVED_NODE_ATTRIBUTE_NAMES
IGNORE_LINK_ATTRIBUTES = RESERVED_LINK_ATTRIBUTE_NAMES

25 

26 

def score_global_key(item: JSONValue) -> float:
    """Sort key returning the ``score["global"]`` value of a search result.

    Returns ``0.0`` whenever ``item`` is not a dict, holds no dict under
    ``"score"``, or its ``"global"`` entry is not an ``int``/``float``.
    """
    if not isinstance(item, dict):
        return 0.0

    score_block = item.get("score")
    if not isinstance(score_block, dict):
        return 0.0

    global_score = score_block.get("global")
    return float(global_score) if isinstance(global_score, (int, float)) else 0.0

35 

@router.get("/constellation/{constellation_uuid}/search/nodes",
    summary="Search for data in a constellation",
    description="Search for specific data within the constellation.",
    response_description="All matching nodes in the constellation",
    responses={
        200: {
            "description": "All matching nodes successfully returned",
            "content": {
                "application/json": {
                    "example": {
                        "success": True,
                        "data":
                        [
                            {
                                "attributes": {
                                    "node_uuid": "0000-000000...",
                                    "title": "example_title",
                                    "content": "example_content"
                                },
                                "labels": [
                                    "Node"
                                ],
                                "score": {
                                    "attributes": {
                                        "title": 95,
                                        "content": 85
                                    },
                                    "labels": {
                                        "Node": 100
                                    }
                                }
                            }
                        ],
                        "message": "All matching nodes successfully returned",
                        "error": None
                    }
                }
            }
        },
        **constellation_check_error,
    },
    tags=["Search"]
)
async def search_node_constellation(constellation_uuid: str = Path(..., description="The UUID of the constellation to get the nodes from"),
                                    search_query: str = Query(..., description="The search query to filter nodes by"),
                                    limit: int = Query(default=100, ge=0, description="The maximum number of nodes to return"),
                                    page: int = Query(default=1, ge=1, description="The page number to return"),
                                    in_filter: Optional[List[str]] = Query(default=None, description="Filter to include only specific attributes in the final results"),
                                    out_filter: Optional[List[str]] = Query(default=None, description="Filter to exclude specific attributes from the final results"),
                                    token: str = Depends(oauth2_scheme)):
    """Fuzzy-search the nodes of a constellation.

    Every node of the constellation that carries none of the reserved labels
    is loaded, its attributes are restricted by ``in_filter``/``out_filter``,
    string values are passed through ``decode_ydoc`` and each remaining
    attribute value and label is scored against the lower-cased
    ``search_query`` with ``fuzz.partial_ratio``. A node is kept when at
    least one of those scores is above 80. Kept nodes are sorted by their
    mean score (descending) and paginated.

    Args:
        constellation_uuid: UUID of the constellation to search in.
        search_query: Text to look for (compared case-insensitively).
        limit: Page size; ``0`` disables pagination and returns every match.
        page: 1-based page number.
        in_filter: When non-empty, only these attribute names are kept.
        out_filter: Attribute names that are dropped. Both filters are
            applied before scoring, so filtered-out attributes cannot match.
        token: Bearer token checked for ``READ`` access.

    Returns:
        The value returned by ``check_constellation_access`` when it is not
        ``None``; otherwise the ``generate_response`` payload whose data is
        the list of matching nodes, each with ``attributes``, ``labels`` and
        a ``score`` breakdown (the internal ``global`` mean is stripped).
    """
    test = check_constellation_access(token, constellation_uuid, "READ")
    if test is not None:
        return test

    included = set(in_filter or [])
    excluded = set(out_filter or [])

    # page is validated as >= 1, so this is 0 for the first page.
    skip = (page - 1) * limit

    # Lowercase the search query for case-insensitive comparison
    search_query = search_query.lower()

    # NOTE: the module-level driver is shared by every request, so it must
    # not be closed here; only the session is released (by the ``with``).
    with driver.session() as session:
        result = session.run("""
            MATCH (n {constellation_uuid: $constellation_uuid})
            WHERE ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(n))
            RETURN properties(n) AS attributes, labels(n) as labels
            """,
            constellation_uuid=constellation_uuid,
            RESERVED_NODE_LABEL_NAMES=RESERVED_NODE_LABEL_NAMES
        )
        # Consume the result while the session is still open.
        candidates = [
            {
                "attributes": {
                    key: serialize_value(value)
                    for key, value in record["attributes"].items()
                    if (not included or key in included) and key not in excluded
                },
                "labels": record["labels"],
            }
            for record in result
        ]

    # Keep only the nodes matching the search query (original order preserved
    # so that the stable sort below breaks ties the same way on every call).
    matches = []
    for node in candidates:
        attributes = node["attributes"]
        labels = node["labels"]
        if not isinstance(attributes, dict) or not isinstance(labels, list):
            continue

        # Decode string values before scoring and before returning them.
        attributes = {
            key: decode_ydoc(value, True) if isinstance(value, str) else value
            for key, value in attributes.items()
        }
        node["attributes"] = attributes

        # Per attribute score (reserved attributes are not searchable)
        attributes_scores: dict[str, float] = {
            key: fuzz.partial_ratio(search_query, str(value).lower())
            for key, value in attributes.items()
            if key not in IGNORE_NODE_ATTRIBUTES
        }
        # Per label score
        labels_scores: dict[str, float] = {
            label: fuzz.partial_ratio(search_query, label.lower())
            for label in labels
            if isinstance(label, str)
        }

        all_scores = [*attributes_scores.values(), *labels_scores.values()]
        # Drop the node when no attribute value or label scores above 80.
        if not any(value > 80 for value in all_scores):
            continue

        node["score"] = cast(JSONValue, {
            "attributes": attributes_scores,
            "labels": labels_scores,
            # Mean score for the node; all_scores is non-empty here because
            # at least one score passed the threshold above.
            "global": sum(all_scores) / len(all_scores),
        })
        matches.append(node)

    # Sort nodes by their global score in descending order
    matches.sort(key=score_global_key, reverse=True)
    if limit > 0:
        # Slicing past the end yields an empty list, which is the desired
        # result for a page beyond the last one.
        matches = matches[skip:skip + limit]

    # Remove global score from nodes
    for node in matches:
        score = node.get("score")
        if isinstance(score, dict):
            score.pop("global", None)

    return generate_response(
        status_code=200,
        data=cast(JSONValue, matches),
        message="All matching nodes successfully returned"
    )

184 

185 

@router.get("/constellation/{constellation_uuid}/search/links",
    summary="Search for data in a constellation",
    description="Search for specific data within the constellation.",
    response_description="All matching links in the constellation",
    responses={
        200: {
            "description": "All matching links successfully returned",
            "content": {
                "application/json": {
                    "example": {
                        "success": True,
                        "data":
                        [
                            {
                                "start_node": "0000-000000...",
                                "end_node": "0000-000000...",
                                # The handler emits the relationship type under "type".
                                "type": "string",
                                "attributes": {
                                    "link_uuid": "0000-000000...",
                                    "title": "example_link_title",
                                    "content": "example_link_content"
                                },
                                "score": {
                                    "attributes": {
                                        "title": 90,
                                        "content": 80
                                    },
                                    "type": {
                                        "string": 100
                                    }
                                }
                            }
                        ],
                        "message": "All matching links successfully returned",
                        "error": None
                    }
                }
            }
        },
        **constellation_check_error,
    },
    tags=["Search"]
)
async def search_link_constellation(constellation_uuid: str = Path(..., description="The UUID of the constellation to get the links from"),
                                    search_query: str = Query(..., description="The search query to filter links by"),
                                    limit: int = Query(default=100, ge=0, description="The maximum number of links to return"),
                                    page: int = Query(default=1, ge=1, description="The page number to return"),
                                    in_filter: Optional[List[str]] = Query(default=None, description="Filter to include only specific attributes in the final results"),
                                    out_filter: Optional[List[str]] = Query(default=None, description="Filter to exclude specific attributes from the final results"),
                                    token: str = Depends(oauth2_scheme)):
    """Fuzzy-search the links (relationships) of a constellation.

    Every relationship of the constellation whose two end nodes carry none
    of the reserved labels is loaded, its attributes are restricted by
    ``in_filter``/``out_filter``, string values are passed through
    ``decode_ydoc`` and each remaining attribute value, plus the
    relationship type, is scored against the lower-cased ``search_query``
    with ``fuzz.partial_ratio``. A link is kept when at least one of those
    scores is above 80. Kept links are sorted by their mean score
    (descending) and paginated.

    Args:
        constellation_uuid: UUID of the constellation to search in.
        search_query: Text to look for (compared case-insensitively).
        limit: Page size; ``0`` disables pagination and returns every match.
        page: 1-based page number.
        in_filter: When non-empty, only these attribute names are kept.
        out_filter: Attribute names that are dropped. Both filters are
            applied before scoring, so filtered-out attributes cannot match.
        token: Bearer token checked for ``READ`` access.

    Returns:
        The value returned by ``check_constellation_access`` when it is not
        ``None``; otherwise the ``generate_response`` payload whose data is
        the list of matching links, each with ``start_node``, ``end_node``,
        ``type``, ``attributes`` and a ``score`` breakdown (the internal
        ``global`` mean is stripped).
    """
    test = check_constellation_access(token, constellation_uuid, "READ")
    if test is not None:
        return test

    included = set(in_filter or [])
    excluded = set(out_filter or [])

    # page is validated as >= 1, so this is 0 for the first page.
    skip = (page - 1) * limit

    # Lowercase the search query for case-insensitive comparison
    search_query = search_query.lower()

    # NOTE: the module-level driver is shared by every request, so it must
    # not be closed here; only the session is released (by the ``with``).
    with driver.session() as session:
        result = session.run("""
            MATCH (n {constellation_uuid: $constellation_uuid})-[r {constellation_uuid: $constellation_uuid}]->(m {constellation_uuid: $constellation_uuid})
            WHERE ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(n)) AND ALL(label IN $RESERVED_NODE_LABEL_NAMES WHERE NOT label IN labels(m))
            RETURN n,m,r, properties(r) AS attributes
            """,
            constellation_uuid=constellation_uuid,
            RESERVED_NODE_LABEL_NAMES=RESERVED_NODE_LABEL_NAMES
        )
        # Consume the result while the session is still open.
        # In the pattern (n)-[r]->(m) the relationship starts at n and ends
        # at m, so n is the start node and m the end node.
        candidates = [
            {
                "start_node": record["n"]["node_uuid"],
                "end_node": record["m"]["node_uuid"],
                "type": record["r"].type,
                "attributes": {
                    key: serialize_value(value)
                    for key, value in record["attributes"].items()
                    if (not included or key in included) and key not in excluded
                },
            }
            for record in result
        ]

    # Keep only the links matching the search query (original order preserved
    # so that the stable sort below breaks ties the same way on every call).
    matches = []
    for link in candidates:
        attributes = link["attributes"]
        link_type = link["type"]
        if not isinstance(attributes, dict) or not isinstance(link_type, str):
            continue

        # Decode string values. They are returned with their original casing;
        # lower-casing happens only on the copies used for scoring below.
        attributes = {
            key: decode_ydoc(value, True) if isinstance(value, str) else value
            for key, value in attributes.items()
        }
        link["attributes"] = attributes

        # Per attribute score (reserved attributes are not searchable)
        attributes_scores: dict[str, float] = {
            key: fuzz.partial_ratio(search_query, str(value).lower())
            for key, value in attributes.items()
            if key not in IGNORE_LINK_ATTRIBUTES
        }
        type_score: float = fuzz.partial_ratio(search_query, link_type.lower())

        # Drop the link when neither an attribute value nor its type scores
        # above 80.
        if type_score <= 80 and not any(value > 80 for value in attributes_scores.values()):
            continue

        link["score"] = cast(JSONValue, {
            "attributes": attributes_scores,
            "type": {link_type: type_score},
            # Mean score for the link (the type always contributes one score,
            # so the divisor is never zero).
            "global": (sum(attributes_scores.values()) + type_score) / (len(attributes_scores) + 1),
        })
        matches.append(link)

    # Sort links by their global score in descending order
    matches.sort(key=score_global_key, reverse=True)
    if limit > 0:
        # Slicing past the end yields an empty list, which is the desired
        # result for a page beyond the last one.
        matches = matches[skip:skip + limit]

    # Remove global score from links
    for link in matches:
        score = link.get("score")
        if isinstance(score, dict):
            score.pop("global", None)

    return generate_response(
        status_code=200,
        data=cast(JSONValue, matches),
        message="All matching links successfully returned"
    )