余弦相似度计算:通用的文本相似度评估方法,通过计算向量之间的夹角来衡量文本的相似度
总体实现思路为:
- 对两段文本进行分词,得到单词列表。
- 将两段文本的单词列表合并,并去除重复的单词,形成词汇表。
- 根据词汇表,将两段文本转换为向量表示。
- 使用余弦相似度公式计算两段文本的相似度。
首先引入maven依赖:
<!-- 中文分词器 -->
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.7.8</version>
</dependency>
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-ooxml</artifactId>
<version>4.0.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version> <!-- 请检查是否有更新的版本 -->
</dependency>
因为我是从excel里读取标准答案和真正答案做相似度平均值计算,所以我也引入了poi依赖。
废话不多说,上代码和注释:
package com.xxx.zjtest.testtest.test;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.*;
/**
* 余弦相似度计算:通用的文本相似度评估方法,通过计算向量之间的夹角来衡量文本的相似度
*
* 对两段文本进行分词,得到单词列表。
* 将两段文本的单词列表合并,并去除重复的单词,形成词汇表。
* 根据词汇表,将两段文本转换为向量表示。
* 使用余弦相似度公式计算两段文本的相似度。
*/
public class CosineSimilarityCalculatorWithoutVocabulary {
public static double calculateCosineSimilarity(String text1, String text2) {
// 分词
List<String> words1 = tokenize(text1);
List<String> words2 = tokenize(text2);
// 创建词汇表
Set<String> vocabulary = new HashSet<>(words1);
vocabulary.addAll(words2);
// 将文本转换为向量
RealVector vector1 = convertTextToVector(words1, vocabulary);
RealVector vector2 = convertTextToVector(words2, vocabulary);
// 计算余弦相似度
double dotProduct = vector1.dotProduct(vector2);
double norm1 = vector1.getNorm();
double norm2 = vector2.getNorm();
double similarity = dotProduct / (norm1 * norm2);
return similarity;
}
//使用hanlp包进行分词
private static List<String> tokenize(String text) {
List<String> words = new ArrayList<>();
Segment segment = HanLP.newSegment().enableCustomDictionary(false).enablePlaceRecognize(true).enableOrganizationRecognize(true);
for (Term term : segment.seg(text)) {
words.add(term.word); // 获取分词后的单词
}
return words;
}
private static RealVector convertTextToVector(List<String> words, Set<String> vocabulary) {
double[] vector = new double[vocabulary.size()];
// 计算词频
int index = 0;
for (String word : vocabulary) {
int count = Collections.frequency(words, word);
vector[index++] = count;
}
// 将数组转换为Apache Commons Math的RealVector对象
RealVector realVector = MatrixUtils.createRealVector(vector);
return realVector;
}
public static double calculateAverageCSCWV(List<String> referenceTexts, List<String> hypothesisTexts) {
if (referenceTexts.size() != hypothesisTexts.size()) {
throw new IllegalArgumentException("参考文本列表和假设文本列表的长度必须相等");
}
double totalRougeL = 0;
for (int i = 0; i < referenceTexts.size(); i++) {
totalRougeL += calculateCosineSimilarity(referenceTexts.get(i), hypothesisTexts.get(i));
}
return totalRougeL / referenceTexts.size();
}
//excel操作类
private static void readColumnData(String filePath, List<String> list, int columnIndex) {
try (FileInputStream file = new FileInputStream(new File(filePath))) {
// 创建Workbook实例
Workbook workbook = new XSSFWorkbook(file);
// 获取第一个工作表
Sheet sheet = workbook.getSheetAt(0);
// 遍历每一行
for (Row row : sheet) {
// 获取单元格,注意这里列索引是从0开始的
Cell cell = row.getCell(columnIndex);
if (cell != null && cell.getCellType() == CellType.STRING) {
list.add(cell.getStringCellValue());
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
public static void main(String[] args) {
String filePath1 = "D:\\项目资料\\ai\\训练资料\\StandardAnswer.xlsx";
String filePath2 = "D:\\项目资料\\ai\\训练资料\\AssAnswer.xlsx";
List<String> referenceTexts = new ArrayList<>();
List<String> hypothesisTexts = new ArrayList<>();
readColumnData(filePath1, referenceTexts, 0); // 读取第一个Excel的第一列 参考答案
readColumnData(filePath2, hypothesisTexts, 0); // 读取第二个Excel的第一列 大模型答案
double similarity = calculateAverageCSCWV(referenceTexts, hypothesisTexts);
System.out.println("平均余弦相似度:" + similarity);
}
}
思路和代码以及注释都有了,完结撒花o( ̄▽ ̄)ブ