第2098篇:知识图谱与LLM的融合——GraphRAG的工程实现
第2098篇:知识图谱与LLM的融合——GraphRAG的工程实现
适读人群:需要处理复杂关系查询的RAG系统工程师 | 阅读时长:约20分钟 | 核心价值:掌握知识图谱构建、图查询与向量检索的融合(GraphRAG),以及在Java中对接Neo4j的实现路径
标准RAG有一个经典失效场景:多跳推理。
用户问:"张三的直属上司的部门,今年有哪些在建项目?"
向量检索返回的是语义相似的文档片段,但这个问题涉及三层关系:张三→上司→部门→项目。单纯的向量检索很难把这条链路串起来。
知识图谱(Knowledge Graph)正好擅长处理关系型数据,而LLM擅长理解自然语言。GraphRAG把两者融合:用图查询找到相关实体和关系,再把这些结构化信息注入LLM的上下文,最终生成答案。
这篇文章把GraphRAG的核心实现路径讲清楚。
架构概览
Neo4j连接和基础操作
/**
 * Spring configuration that wires the Neo4j connectivity beans.
 *
 * <p>Exposes a pooled {@link Driver} built from the {@code neo4j.*} properties and a
 * {@link Neo4jClient} created on top of that driver.
 */
@Configuration
public class Neo4jConfig {

    /** Maximum number of connections the driver keeps in its pool. */
    private static final int MAX_POOL_SIZE = 50;

    /** Longest wait for a pooled connection, in seconds. */
    private static final long ACQUISITION_TIMEOUT_SECONDS = 5;

    /** Connection URI, e.g. {@code bolt://localhost:7687}. */
    @Value("${neo4j.uri}")
    private String uri;

    @Value("${neo4j.username}")
    private String username;

    @Value("${neo4j.password}")
    private String password;

    /**
     * Builds the shared driver using basic authentication and the pool limits above.
     *
     * @return the driver bean
     */
    @Bean
    public Driver neo4jDriver() {
        var poolSettings = Config.builder()
                .withMaxConnectionPoolSize(MAX_POOL_SIZE)
                .withConnectionAcquisitionTimeout(ACQUISITION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .build();
        var credentials = AuthTokens.basic(username, password);
        return GraphDatabase.driver(uri, credentials, poolSettings);
    }

    /**
     * Creates the {@link Neo4jClient} from the given driver.
     *
     * @param driver the driver bean
     * @return the client bean
     */
    @Bean
    public Neo4jClient neo4jClient(Driver driver) {
        return Neo4jClient.create(driver);
    }
}
/**
 * Basic knowledge-graph operations on top of the Neo4j {@code Driver}.
 *
 * <p>Each method opens its own session, runs a single Cypher statement and maps the
 * returned records to plain Java values.
 */
@Repository
@RequiredArgsConstructor
@Slf4j
public class KnowledgeGraphRepository {

    /**
     * Accepted shape for labels and relationship types that must be spliced into the
     * Cypher text (they cannot be passed as query parameters): a Unicode letter or
     * underscore followed by letters, digits or underscores.
     */
    private static final java.util.regex.Pattern SAFE_IDENTIFIER =
            java.util.regex.Pattern.compile("[\\p{L}_][\\p{L}\\p{N}_]*");

    private final Driver driver;

    /**
     * Finds the outgoing one-hop relations of the node whose {@code name} equals
     * {@code entityName}.
     *
     * @param entityName value matched against the node's {@code name} property
     * @param limit maximum number of relations to return
     * @return one {@link GraphRelation} per matched relationship; a missing target name or
     *         target type is reported as an empty string
     */
    public List<GraphRelation> findDirectRelations(String entityName, int limit) {
        String cypher = """
                MATCH (source {name: $name})-[r]->(target)
                RETURN source.name AS source, type(r) AS relation, target.name AS target,
                target.type AS targetType, properties(r) AS relProperties
                LIMIT $limit
                """;
        try (Session session = driver.session()) {
            return session.run(cypher, Map.of("name", entityName, "limit", limit))
                    .list(rec -> new GraphRelation(
                            rec.get("source").asString(),
                            rec.get("relation").asString(),
                            // A target without a name must not abort the whole result list.
                            rec.get("target").asString(""),
                            rec.get("targetType").asString(""),
                            rec.get("relProperties").asMap(KnowledgeGraphRepository::stringify)));
        }
    }

    /**
     * Finds a shortest path (ignoring relationship direction) between two named entities.
     *
     * @param fromEntity {@code name} of the start node
     * @param toEntity {@code name} of the end node
     * @param maxHops upper bound on the path length; values below 1 yield an empty list
     * @return up to five paths as node-name / relationship-type sequences
     */
    public List<GraphPath> findPaths(String fromEntity, String toEntity, int maxHops) {
        if (maxHops < 1) {
            // "*1..0" or a negative bound cannot describe any path; skip the round trip.
            return List.of();
        }
        // The hop bound is an int, so formatting it into the pattern cannot inject Cypher.
        String cypher = """
                MATCH path = shortestPath((from {name: $from})-[*1..%d]-(to {name: $to}))
                RETURN [node in nodes(path) | node.name] AS nodes,
                [rel in relationships(path) | type(rel)] AS relations
                LIMIT 5
                """.formatted(maxHops);
        try (Session session = driver.session()) {
            return session.run(cypher, Map.of("from", fromEntity, "to", toEntity))
                    .list(rec -> {
                        // Intermediate nodes may lack a name; render those as "".
                        List<String> nodes = rec.get("nodes").asList(v -> v.asString(""));
                        List<String> relations = rec.get("relations").asList(Value::asString);
                        return new GraphPath(nodes, relations);
                    });
        }
    }

    /**
     * Lists the targets reachable from a typed entity through one relationship type
     * (for example all employees of a company), ordered by target name.
     *
     * @param entityName {@code name} of the source node
     * @param entityType label of the source node
     * @param relationshipType relationship type to follow
     * @param targetType label of the target nodes
     * @return rows with the keys {@code name}, {@code description} and {@code relProperties}
     * @throws IllegalArgumentException if a label or relationship type is null or is not a
     *         plain identifier (these values are embedded into the query text)
     */
    public List<Map<String, Object>> queryByRelationType(
            String entityName, String entityType,
            String relationshipType, String targetType) {
        String cypher = """
                MATCH (e:%s {name: $name})-[r:%s]->(target:%s)
                RETURN target.name AS name, target.description AS description,
                properties(r) AS relProps
                ORDER BY target.name
                """.formatted(
                        requireIdentifier(entityType, "entityType"),
                        requireIdentifier(relationshipType, "relationshipType"),
                        requireIdentifier(targetType, "targetType"));
        try (Session session = driver.session()) {
            return session.run(cypher, Map.of("name", entityName))
                    .list(rec -> {
                        // Filled key by key so the LinkedHashMap really keeps this order.
                        Map<String, Object> row = new LinkedHashMap<>();
                        row.put("name", rec.get("name").asString(""));
                        row.put("description", rec.get("description").asString(""));
                        row.put("relProperties",
                                rec.get("relProps").asMap(KnowledgeGraphRepository::stringify));
                        return row;
                    });
        }
    }

    /**
     * Runs an arbitrary Cypher statement (for complex cases) and returns each record as an
     * ordered column-name to value map.
     *
     * @param cypher statement text
     * @param params statement parameters
     * @return one map per record
     * @throws RuntimeException wrapping any failure raised while running the statement
     */
    public List<Map<String, Object>> executeQuery(String cypher, Map<String, Object> params) {
        try (Session session = driver.session()) {
            return session.run(cypher, params)
                    .list(rec -> {
                        Map<String, Object> row = new LinkedHashMap<>();
                        for (String key : rec.keys()) {
                            row.put(key, convertValue(rec.get(key)));
                        }
                        return row;
                    });
        } catch (Exception e) {
            log.error("Cypher查询失败: {}", e.getMessage());
            throw new RuntimeException("图数据库查询失败", e);
        }
    }

    /** Converts a driver value to a plain Java object based on its reported type name. */
    private Object convertValue(Value value) {
        return switch (value.type().name()) {
            case "STRING" -> value.asString();
            case "INTEGER" -> value.asLong();
            case "FLOAT" -> value.asDouble();
            case "BOOLEAN" -> value.asBoolean();
            case "NULL" -> null;
            default -> value.asObject();
        };
    }

    /**
     * Renders a property value as text via its Java object form. {@code Value#toString}
     * is avoided because it yields the driver's display form (string values come back
     * wrapped in double quotes).
     */
    private static String stringify(Value value) {
        return String.valueOf(value.asObject());
    }

    /** Returns {@code identifier} unchanged if it is safe to embed in Cypher text. */
    private static String requireIdentifier(String identifier, String parameterName) {
        if (identifier == null || !SAFE_IDENTIFIER.matcher(identifier).matches()) {
            throw new IllegalArgumentException(
                    parameterName + " is not a valid Cypher identifier: " + identifier);
        }
        return identifier;
    }

    /** One directed relationship: source name, relationship type, target name and type. */
    public record GraphRelation(
            String source, String relation, String target,
            String targetType, Map<String, String> properties
    ) {}

    /** A path expressed as node names plus the relationship types between them. */
    public record GraphPath(List<String> nodes, List<String> relations) {

        /**
         * Renders the path as {@code a -[REL]-> b -[REL]-> c}.
         *
         * <p>The arrows follow the traversal order of the lists, not the stored direction
         * of each relationship. An empty or missing node list renders as an empty string,
         * and surplus relations without a following node are ignored.
         *
         * @return the human-readable path
         */
        public String toReadableString() {
            if (nodes == null || nodes.isEmpty()) {
                return "";
            }
            StringBuilder sb = new StringBuilder(nodes.get(0));
            int hops = relations == null ? 0 : Math.min(relations.size(), nodes.size() - 1);
            for (int i = 0; i < hops; i++) {
                sb.append(" -[").append(relations.get(i)).append("]-> ");
                sb.append(nodes.get(i + 1));
            }
            return sb.toString();
        }
    }
}
实体抽取:从问题到图查询
/**
 * Extracts entities from a user question.
 *
 * <p>This is the first GraphRAG step: identify which entities the user is asking about
 * so they can be looked up in the graph database.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class EntityExtractionService {

    /** Shared JSON parser; built once instead of per call. */
    private static final ObjectMapper JSON = new ObjectMapper();

    /** Quoted-span matcher used by the rule-based fallback; compiled once. */
    private static final java.util.regex.Pattern QUOTED_ENTITY =
            java.util.regex.Pattern.compile("[\"「」'']([^\"「」'']+)[\"「」'']");

    private final ChatLanguageModel llm;
    private final Neo4jClient neo4jClient;

    /**
     * Extracts the entities, candidate relationship types and intent of a question.
     *
     * <p>The LLM is asked for a JSON answer. If the call fails, or the answer cannot be
     * parsed, the rule-based fallback is used instead.
     *
     * @param question the user question
     * @return the extraction result
     */
    public EntityExtractionResult extract(String question) {
        String prompt = """
                从以下问题中抽取需要在知识图谱中查询的实体。
                问题:%s
                请返回JSON格式:
                {
                "entities": [
                {"name": "实体名称", "type": "实体类型(Person/Organization/Project/Product等)"}
                ],
                "relationships": ["可能涉及的关系类型,如:WORKS_FOR, BELONGS_TO等"],
                "queryIntent": "问题的核心意图(一句话)"
                }
                只返回JSON,不要其他文字。
                """.formatted(question);
        try {
            String response = llm.generate(prompt);
            return parseEntityResult(response, question);
        } catch (Exception e) {
            log.warn("实体抽取失败,使用规则兜底: {}", e.getMessage());
            return extractByRules(question);
        }
    }

    /**
     * Entity disambiguation: maps each extracted name to an actual node in the graph.
     *
     * <p>Example: the user says "张三" while the graph stores "张三(工号12345)", so a
     * fuzzy match is needed to find the right node. Entities that cannot be resolved are
     * dropped from the result.
     *
     * @param entities the extracted entities
     * @return the entities that could be matched to a graph node
     */
    public List<ResolvedEntity> resolveEntities(List<ExtractedEntity> entities) {
        return entities.stream()
                .map(this::resolveOne)
                .filter(e -> e != null)
                .toList();
    }

    /**
     * Resolves one entity: exact {@code name} match first, then the {@code entityIndex}
     * full-text index.
     *
     * @return the resolved entity, or {@code null} if nothing matched or the lookup failed
     */
    private ResolvedEntity resolveOne(ExtractedEntity entity) {
        // Exact match first.
        String exactCypher = """
                MATCH (n {name: $name})
                RETURN n.name AS name, labels(n) AS labels, n.id AS id
                LIMIT 1
                """;
        // Plain try blocks: the fetched Optional / Collection are not AutoCloseable, so
        // they cannot be used as try-with-resources resources.
        try {
            var exact = neo4jClient.query(exactCypher)
                    .bind(entity.name()).to("name")
                    .fetch().one();
            if (exact.isPresent()) {
                var row = exact.get();
                return new ResolvedEntity(entity.name(),
                        (String) row.get("name"),
                        entity.type(), 1.0);
            }
        } catch (Exception e) {
            log.debug("精确匹配失败: {}", entity.name(), e);
        }
        // Fuzzy match through the full-text index.
        // NOTE(review): the raw name is passed as the index query string; names with
        // query-syntax characters may fail to parse and end up in the catch below.
        String fuzzyCypher = """
                CALL db.index.fulltext.queryNodes('entityIndex', $query)
                YIELD node, score
                WHERE score > 0.5
                RETURN node.name AS name, labels(node) AS labels, score
                ORDER BY score DESC
                LIMIT 3
                """;
        try {
            var candidates = neo4jClient.query(fuzzyCypher)
                    .bind(entity.name()).to("query")
                    .fetch().all();
            return candidates.stream()
                    .findFirst()
                    .map(row -> new ResolvedEntity(
                            entity.name(),
                            (String) row.get("name"),
                            entity.type(),
                            // Go through Number so the cast does not depend on the exact
                            // boxed numeric type of the score column.
                            row.get("score") instanceof Number score ? score.doubleValue() : 0.0))
                    .orElse(null);
        } catch (Exception e) {
            log.warn("模糊匹配失败: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Parses the LLM answer. Entities without a name are skipped; if the answer is not
     * parseable JSON the rule-based extraction is used.
     */
    private EntityExtractionResult parseEntityResult(String json, String question) {
        try {
            JsonNode node = JSON.readTree(extractJson(json));
            List<ExtractedEntity> entities = new ArrayList<>();
            for (JsonNode e : node.path("entities")) {
                String name = e.path("name").asText();
                if (name.isBlank()) {
                    // A nameless entity cannot be looked up in the graph.
                    continue;
                }
                entities.add(new ExtractedEntity(name, e.path("type").asText()));
            }
            List<String> relationships = new ArrayList<>();
            for (JsonNode r : node.path("relationships")) {
                relationships.add(r.asText());
            }
            String queryIntent = node.path("queryIntent").asText(question);
            return new EntityExtractionResult(entities, relationships, queryIntent);
        } catch (Exception e) {
            log.warn("实体抽取结果解析失败,使用规则兜底: {}", e.getMessage());
            return extractByRules(question);
        }
    }

    /**
     * Fallback: treats quoted spans in the question as entities of type {@code Unknown}
     * (users may mark entities with quotes).
     */
    private EntityExtractionResult extractByRules(String question) {
        List<ExtractedEntity> entities = new ArrayList<>();
        java.util.regex.Matcher m = QUOTED_ENTITY.matcher(question);
        while (m.find()) {
            entities.add(new ExtractedEntity(m.group(1), "Unknown"));
        }
        return new EntityExtractionResult(entities, List.of(), question);
    }

    /**
     * Cuts out the text between the first '{' and the last '}'; returns the input
     * unchanged when no such span exists.
     */
    private String extractJson(String s) {
        int start = s.indexOf('{');
        int end = s.lastIndexOf('}');
        return (start >= 0 && end > start) ? s.substring(start, end + 1) : s;
    }

    /** An entity name and type as extracted from the question. */
    public record ExtractedEntity(String name, String type) {}

    /** An extracted entity matched to a graph node name, with a match confidence. */
    public record ResolvedEntity(
            String originalName, String resolvedName, String type, double confidence) {}

    /** Entities, candidate relationship types and the intent of one question. */
    public record EntityExtractionResult(
            List<ExtractedEntity> entities, List<String> relationships, String queryIntent) {}
}
Cypher查询生成
/**
 * Generates Cypher queries with an LLM.
 *
 * <p>More flexible than hard-coded query templates; the LLM is given the graph schema
 * description so it knows which node and relationship types exist.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class CypherQueryGenerator {

    /** Shared JSON parser; built once instead of per call. */
    private static final ObjectMapper JSON = new ObjectMapper();

    /**
     * Write clauses matched as whole words, case-insensitively. Whole-word matching keeps
     * property names such as {@code createdAt} or {@code dataset} from being mistaken for
     * CREATE / SET.
     */
    private static final java.util.regex.Pattern WRITE_CLAUSE = java.util.regex.Pattern.compile(
            "\\b(CREATE|MERGE|DELETE|DETACH|SET|REMOVE|DROP|FOREACH|LOAD\\s+CSV)\\b",
            java.util.regex.Pattern.CASE_INSENSITIVE);

    /** A {@code $name}-style parameter reference. */
    private static final java.util.regex.Pattern PARAMETER_REFERENCE =
            java.util.regex.Pattern.compile("\\$[\\p{L}_]");

    private final ChatLanguageModel llm;
    private final GraphSchemaService schemaService;

    /**
     * Generates Cypher queries for a question and its already-resolved entities.
     *
     * <p>Only the query text is returned, so the prompt asks the model to inline values
     * as literals rather than use {@code $} parameters that no caller could supply.
     *
     * @param question the user question
     * @param entities entities resolved against the graph
     * @return the accepted read-only queries; empty if generation or parsing fails
     */
    public List<String> generateCypherQueries(
            String question,
            List<EntityExtractionService.ResolvedEntity> entities) {
        // Graph schema (node types, relationship types).
        String schemaDescription = schemaService.getSchemaDescription();
        String entityContext = entities.stream()
                .map(e -> String.format("- 实体:%s(图中对应:%s)",
                        e.originalName(), e.resolvedName()))
                .collect(Collectors.joining("\n"));
        String prompt = """
                你是一个Neo4j Cypher查询专家。请根据用户问题生成Cypher查询。
                **图数据库Schema**:
                %s
                **已识别的实体**:
                %s
                **用户问题**:
                %s
                **要求**:
                1. 生成1-3个Cypher查询,覆盖问题的不同角度
                2. 每个查询独立,都能单独执行
                3. 查询要安全(不要DELETE/MERGE等写操作)
                4. 限制返回数量(LIMIT 20以内)
                5. 不要使用$参数占位符,把实体名称等取值直接写成字符串字面量(调用方不会传入参数)
                返回JSON格式:
                {
                "queries": [
                {
                "cypher": "MATCH ...",
                "purpose": "这个查询的目的"
                }
                ]
                }
                只返回JSON。
                """.formatted(schemaDescription, entityContext, question);
        try {
            String response = llm.generate(prompt);
            return parseCypherQueries(response);
        } catch (Exception e) {
            log.warn("Cypher生成失败: {}", e.getMessage());
            return List.of();
        }
    }

    /**
     * Pulls the {@code cypher} strings out of the LLM answer, keeping only queries that
     * pass the read-only check and do not reference parameters.
     */
    private List<String> parseCypherQueries(String response) {
        try {
            JsonNode root = JSON.readTree(extractJson(response));
            List<String> queries = new ArrayList<>();
            for (JsonNode q : root.path("queries")) {
                String cypher = q.path("cypher").asText("");
                if (cypher.isBlank()) {
                    continue;
                }
                // Safety check: read operations only.
                if (!isReadOnly(cypher)) {
                    log.warn("丢弃包含写操作的Cypher: {}", cypher);
                    continue;
                }
                // The returned list carries no parameter values, so a query that still
                // references $parameters could never be executed successfully.
                if (PARAMETER_REFERENCE.matcher(cypher).find()) {
                    log.warn("丢弃包含参数占位符的Cypher: {}", cypher);
                    continue;
                }
                queries.add(cypher);
            }
            return queries;
        } catch (Exception e) {
            log.warn("Cypher解析失败: {}", e.getMessage());
            return List.of();
        }
    }

    /**
     * Keyword-level screen for write clauses. It is conservative (a string literal that
     * contains one of the words is rejected too) and does not inspect procedure calls.
     */
    private boolean isReadOnly(String cypher) {
        return !WRITE_CLAUSE.matcher(cypher).find();
    }

    /**
     * Cuts out the text between the first '{' and the last '}'; returns the input
     * unchanged when no such span exists.
     */
    private String extractJson(String s) {
        int start = s.indexOf('{');
        int end = s.lastIndexOf('}');
        return (start >= 0 && end > start) ? s.substring(start, end + 1) : s;
    }
}
GraphRAG编排器
/**
 * GraphRAG orchestrator.
 *
 * <p>Merges graph-query results with vector-search results and asks the LLM for the
 * final answer.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class GraphRagOrchestrator {

    private final EntityExtractionService entityExtractor;
    private final CypherQueryGenerator cypherGenerator;
    private final KnowledgeGraphRepository graphRepo;
    // Parameterized so search results are typed without an unchecked assignment.
    private final EmbeddingStore<TextSegment> vectorStore;
    private final EmbeddingModel embeddingModel;
    private final ChatLanguageModel llm;

    /**
     * Main GraphRAG flow: entity extraction, graph lookup, vector search, answer
     * generation.
     *
     * @param userQuestion the question to answer
     * @return the answer together with the contexts and resolved entities used for it
     */
    public GraphRagAnswer query(String userQuestion) {
        log.info("GraphRAG查询: {}", userQuestion);
        // Step 1: entity extraction.
        var extractionResult = entityExtractor.extract(userQuestion);
        var resolvedEntities = entityExtractor.resolveEntities(extractionResult.entities());
        log.debug("抽取实体: {}, 解析成功: {}",
                extractionResult.entities().size(), resolvedEntities.size());
        // Step 2: graph query (structured relationship information).
        String graphContext = fetchGraphContext(userQuestion, resolvedEntities);
        // Step 3: vector search (unstructured text).
        String vectorContext = fetchVectorContext(userQuestion);
        // Step 4: merge the contexts and generate the answer.
        String answer = generateAnswer(userQuestion, graphContext, vectorContext,
                extractionResult.queryIntent());
        return new GraphRagAnswer(answer, graphContext, vectorContext, resolvedEntities);
    }

    /**
     * Collects graph facts for the resolved entities: their direct relations plus the
     * rows returned by the LLM-generated Cypher queries.
     *
     * @return the formatted graph section, or an empty string when there are no entities
     *         or no facts were found (so a bare header never counts as context)
     */
    private String fetchGraphContext(String question,
                                     List<EntityExtractionService.ResolvedEntity> entities) {
        if (entities.isEmpty()) {
            log.debug("无已解析实体,跳过图查询");
            return "";
        }
        StringBuilder facts = new StringBuilder();
        // 1. Direct relations of every entity.
        for (var entity : entities) {
            try {
                List<KnowledgeGraphRepository.GraphRelation> relations =
                        graphRepo.findDirectRelations(entity.resolvedName(), 20);
                if (!relations.isEmpty()) {
                    facts.append("\n【").append(entity.resolvedName()).append("】的关系:\n");
                    relations.forEach(r -> facts.append(String.format(
                            " %s -[%s]-> %s\n", r.source(), r.relation(), r.target())));
                }
            } catch (Exception e) {
                // Same tolerance as for the generated queries below: one failed lookup
                // must not discard everything else.
                log.warn("直接关系查询失败: {}", e.getMessage());
            }
        }
        // 2. LLM-generated Cypher queries (covers multi-hop questions).
        List<String> cyphers = cypherGenerator.generateCypherQueries(question, entities);
        for (String cypher : cyphers) {
            try {
                List<Map<String, Object>> results = graphRepo.executeQuery(cypher, Map.of());
                if (!results.isEmpty()) {
                    facts.append("\n查询结果:\n");
                    results.forEach(row -> facts.append(" ").append(row).append("\n"));
                }
            } catch (Exception e) {
                log.warn("Cypher执行失败: {}", e.getMessage());
            }
        }
        if (facts.length() == 0) {
            return "";
        }
        return "=== 知识图谱信息 ===\n" + facts;
    }

    /**
     * Runs a vector search for the question (up to 5 matches with score at least 0.7).
     *
     * @return the matched segment texts as one section, or an empty string if none matched
     */
    private String fetchVectorContext(String question) {
        Embedding questionEmbedding = embeddingModel.embed(question).content();
        EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(questionEmbedding)
                .maxResults(5)
                .minScore(0.7)
                .build();
        List<EmbeddingMatch<TextSegment>> matches = vectorStore.search(searchRequest).matches();
        if (matches.isEmpty()) {
            return "";
        }
        StringBuilder context = new StringBuilder();
        context.append("=== 相关文档信息 ===\n");
        matches.forEach(match -> context.append("\n")
                .append(match.embedded().text())
                .append("\n"));
        return context.toString();
    }

    /**
     * Builds the answer prompt from the available contexts and returns the trimmed LLM
     * output. When both contexts are empty a fixed not-found message is returned and the
     * LLM is not called.
     */
    private String generateAnswer(
            String question, String graphContext,
            String vectorContext, String queryIntent) {
        boolean hasGraphContext = !graphContext.isEmpty();
        boolean hasVectorContext = !vectorContext.isEmpty();
        if (!hasGraphContext && !hasVectorContext) {
            return "抱歉,我在知识库中没有找到与您问题相关的信息。";
        }
        StringBuilder contextSection = new StringBuilder();
        if (hasGraphContext) {
            contextSection.append(graphContext).append("\n");
        }
        if (hasVectorContext) {
            contextSection.append(vectorContext).append("\n");
        }
        String prompt = """
                请基于以下知识库信息回答用户问题。
                **知识库信息**:
                %s
                **用户问题**:
                %s
                **问题意图**:
                %s
                **回答要求**:
                - 直接回答问题,不要重复问题
                - 如果信息不完整,说明哪些信息是已知的、哪些是未知的
                - 对于关系型信息(图谱数据),用清晰的结构表达
                - 不要编造知识库中没有的内容
                """.formatted(contextSection, question, queryIntent);
        return llm.generate(prompt).trim();
    }

    /** The answer plus the graph context, vector context and entities behind it. */
    public record GraphRagAnswer(
            String answer, String graphContext, String vectorContext,
            List<EntityExtractionService.ResolvedEntity> resolvedEntities
    ) {}
}
知识图谱的构建与更新
/**
 * Builds the knowledge graph from documents.
 *
 * <p>Extracts entities and relationships from unstructured text with the LLM and writes
 * them to Neo4j.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class KnowledgeGraphBuilder {

    /** Shared JSON parser; built once instead of per call. */
    private static final ObjectMapper JSON = new ObjectMapper();

    private final ChatLanguageModel llm;
    private final Driver neo4jDriver;

    /**
     * Extracts knowledge triples from a text and stores them in the graph.
     *
     * @param text source text; only the first 3000 characters are sent to the LLM
     * @param sourceDocument identifier stored on each created relationship as {@code source}
     * @return counts of extracted and stored triples plus a status message
     */
    public KnowledgeExtractionResult extractAndStore(String text, String sourceDocument) {
        if (text == null || text.isBlank()) {
            // Nothing to extract from; skip the LLM call.
            return new KnowledgeExtractionResult(0, 0, "无三元组提取结果");
        }
        // Step 1: extract triples with the LLM.
        List<Triple> triples = extractTriples(text);
        if (triples.isEmpty()) {
            return new KnowledgeExtractionResult(0, 0, "无三元组提取结果");
        }
        // Step 2: write them to Neo4j.
        int storedCount = storeTriples(triples, sourceDocument);
        log.info("知识提取完成: 提取三元组={}, 写入成功={}", triples.size(), storedCount);
        return new KnowledgeExtractionResult(triples.size(), storedCount, "成功");
    }

    /** Asks the LLM for triples in JSON form; returns an empty list if the call fails. */
    private List<Triple> extractTriples(String text) {
        String prompt = """
                从以下文本中提取知识三元组(主体-关系-客体)。
                文本:
                %s
                提取规则:
                1. 只提取明确陈述的关系,不推断
                2. 实体名称标准化(去除称谓等修饰词)
                3. 关系类型用英文(如:WORKS_FOR, BELONGS_TO, MANAGES等)
                返回JSON:
                {
                "triples": [
                {"subject": "实体1", "subjectType": "类型",
                "relation": "RELATION_TYPE",
                "object": "实体2", "objectType": "类型",
                "properties": {"key": "value"}}
                ]
                }
                只返回JSON。
                """.formatted(truncate(text, 3000));
        try {
            String response = llm.generate(prompt);
            return parseTriples(response);
        } catch (Exception e) {
            log.warn("三元组提取失败: {}", e.getMessage());
            return List.of();
        }
    }

    /**
     * Writes the triples one statement at a time; a failing triple is logged and skipped.
     *
     * @return the number of triples whose statement completed without error
     */
    private int storeTriples(List<Triple> triples, String sourceDocument) {
        // MERGE avoids creating duplicate nodes.
        // NOTE(review): nodes are merged on `name` without a label — presumably this
        // cannot use an index; confirm before loading large volumes.
        String cypher = """
                MERGE (s {name: $subjectName})
                ON CREATE SET s.type = $subjectType, s.createdAt = datetime()
                WITH s
                MERGE (o {name: $objectName})
                ON CREATE SET o.type = $objectType, o.createdAt = datetime()
                WITH s, o
                CALL apoc.merge.relationship(s, $relationshipType, {},
                {source: $source, createdAt: datetime()}, o)
                YIELD rel
                RETURN count(rel) AS created
                """;
        int success = 0;
        try (Session session = neo4jDriver.session()) {
            for (Triple triple : triples) {
                try {
                    // consume() drains the result so a failure of this statement is
                    // raised here, before the triple is counted as stored.
                    session.run(cypher, Map.of(
                            "subjectName", triple.subject(),
                            "subjectType", triple.subjectType(),
                            "objectName", triple.object(),
                            "objectType", triple.objectType(),
                            "relationshipType", triple.relation(),
                            "source", sourceDocument
                    )).consume();
                    success++;
                } catch (Exception e) {
                    log.warn("三元组写入失败: {}", triple, e);
                }
            }
        }
        return success;
    }

    /**
     * Parses the LLM answer into triples. Triples with a blank subject, relation or
     * object are skipped; an answer without a JSON object, or unparseable JSON, yields
     * an empty list.
     */
    private List<Triple> parseTriples(String json) {
        int start = json.indexOf('{');
        int end = json.lastIndexOf('}');
        if (start < 0 || end <= start) {
            log.warn("三元组解析失败: 响应中没有JSON对象");
            return List.of();
        }
        try {
            JsonNode root = JSON.readTree(json.substring(start, end + 1));
            List<Triple> result = new ArrayList<>();
            for (JsonNode t : root.path("triples")) {
                String subject = t.path("subject").asText();
                String relation = t.path("relation").asText();
                String object = t.path("object").asText();
                if (subject.isBlank() || relation.isBlank() || object.isBlank()) {
                    // An incomplete triple would create a nameless node or untyped edge.
                    continue;
                }
                result.add(new Triple(
                        subject,
                        t.path("subjectType").asText("Entity"),
                        relation,
                        object,
                        t.path("objectType").asText("Entity")
                ));
            }
            return result;
        } catch (Exception e) {
            log.warn("三元组解析失败: {}", e.getMessage());
            return List.of();
        }
    }

    /**
     * Shortens {@code text} to at most {@code maxLen} chars and appends "..." when it was
     * cut. The cut is moved back by one char if it would split a surrogate pair.
     */
    private String truncate(String text, int maxLen) {
        if (text.length() <= maxLen) {
            return text;
        }
        int end = maxLen;
        if (end > 0 && Character.isHighSurrogate(text.charAt(end - 1))) {
            end--;
        }
        return text.substring(0, end) + "...";
    }

    /** A subject-relation-object statement with the types of both entities. */
    public record Triple(
            String subject, String subjectType, String relation,
            String object, String objectType) {}

    /** Outcome of one extraction run: extracted count, stored count and a status text. */
    public record KnowledgeExtractionResult(
            int extractedCount, int storedCount, String status) {}
}
实践建议
什么时候需要GraphRAG
不是所有RAG场景都需要引入图数据库。GraphRAG适合的场景是:涉及多实体间的关系推理("A的上司的部门有哪些项目"),或者实体关系网络很重要(人员组织关系、产品依赖关系、知识体系)。如果你的问题都是"查文档找信息"而不是"推断关系",标准RAG就够了,不用引入图数据库的复杂性。
知识图谱的冷启动
构建知识图谱需要初始数据。通常有两种路径:一是从现有结构化数据(数据库、Excel)导入基础实体和关系;二是用LLM从现有文档中抽取三元组(就是上面的KnowledgeGraphBuilder)。两种路径结合效果最好:结构化数据保证核心实体准确,文档抽取补充关系信息。
实体解析是最大的坑
我们在实际项目里花了大量时间处理实体解析问题:同一个人可能有多种叫法(张三、张三三、张工),同名问题(两个"李明"),实体演变(公司改名了)。这些问题在小数据量时不明显,但随着图谱规模增长会越来越严重。建议一开始就为每个实体设计稳定的ID体系,用ID做主键而不是名称。
