package com.doumee.core.tsp;
|
|
/**
 * Distance-threshold clustering (grouping) of sketch customers.
 */
|
import com.doumee.core.utils.Constants;
import com.doumee.dao.admin.request.SketchCateModel;
import com.doumee.dao.business.model.JkSketchCustomer;
import com.doumee.service.business.impl.JkSketchServiceImpl;

import java.math.BigDecimal;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
|
|
public class Clustering {
|
public static List<SketchCateModel> clusterPoints(List<JkSketchCustomer> points, double threshold) {
|
|
List<SketchCateModel> clusters = new ArrayList<>();
|
boolean[] visited = new boolean[points.size()];
|
int index =0;
|
for (int i = 0; i < points.size(); i++) {
|
if (!visited[i]) {
|
List<JkSketchCustomer> cluster = new ArrayList<>();
|
dfs(points, visited, cluster, i, threshold);
|
SketchCateModel sketchCateModel = new SketchCateModel();
|
sketchCateModel.setCustomerList(cluster);
|
sketchCateModel.setId(index);
|
sketchCateModel.setStartPoint(cluster.get(0));
|
for (JkSketchCustomer c : cluster){
|
sketchCateModel.setTotalNum(Constants.formatBigdecimal(sketchCateModel.getTotalNum()).add(Constants.formatBigdecimal(c.getTotalNum())));
|
}
|
sketchCateModel.setTotalCustomer(cluster.size());
|
clusters.add(sketchCateModel);
|
}
|
}
|
// 打印每个聚类的点
|
for (int i = 0; i < clusters.size(); i++) {
|
System.out.println("Cluster " + (i + 1) + ": " + clusters.get(i).getStartPoint().getName()+ ": " + clusters.get(i).getCustomerList().size());
|
}
|
return clusters;
|
}
|
public static double distanceTo(JkSketchCustomer self, JkSketchCustomer other) {
|
List<DistanceMapParam> distanceMapParamList =JkSketchServiceImpl.getListFromJsonStr(self.getDistanceJson());
|
DistanceMapParam param = JkSketchServiceImpl.getParamByCustomerIds( other.getId(),distanceMapParamList);
|
if(param!=null && param.getDistance()!=0){//如果之前已经获取过
|
return (param.getDistance());
|
}
|
return DistanceCalculator.calculateDistance(Constants.formatBigdecimal(self.getLatitude()).doubleValue()
|
,Constants.formatBigdecimal(self.getLongitude()).doubleValue()
|
,Constants.formatBigdecimal(other.getLatitude()).doubleValue()
|
,Constants.formatBigdecimal(other.getLongitude()).doubleValue());
|
}
|
private static void dfs(List<JkSketchCustomer> points, boolean[] visited, List<JkSketchCustomer> cluster, int startIndex, double threshold) {
|
visited[startIndex] = true;
|
cluster.add(points.get(startIndex));
|
JkSketchCustomer startPoint = points.get(startIndex);
|
|
for (int i = 0; i < points.size(); i++) {
|
if (!visited[i]) {
|
double distance = distanceTo(startPoint,points.get(i));
|
if (distance <= threshold) {
|
dfs(points, visited, cluster, i, threshold); // 递归添加到聚类中
|
}
|
}
|
}
|
}
|
|
/**
|
* 117°40′~118°44′、北纬30°19′~31°34′
|
* @param args
|
*/
|
public static void main(String[] args) {
|
List<JkSketchCustomer> points = new ArrayList<>();
|
for (int i = 0; i <3000; i++) {
|
JkSketchCustomer a = new JkSketchCustomer();
|
a.setLatitude(new BigDecimal(30.19d+(30.54d-30.19d)*Math.random()));
|
a.setLongitude(new BigDecimal(117.40+(117.74d-117.40d)*Math.random()));
|
a.setName("客户"+i);
|
points.add(a);
|
}
|
|
double threshold = 1000; // 设置距离阈值,超过这个距离就不属于同一聚类。
|
clusterPoints(points, threshold);
|
}
|
}
|