package com.doumee.core.tsp; /** * 聚类分组 */ import com.doumee.core.utils.Constants; import com.doumee.dao.admin.request.SketchCateModel; import com.doumee.dao.business.model.JkSketchCustomer; import com.doumee.service.business.impl.JkSketchServiceImpl; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; public class Clustering { public static List clusterPoints(List points, double threshold) { List clusters = new ArrayList<>(); boolean[] visited = new boolean[points.size()]; int index =0; for (int i = 0; i < points.size(); i++) { if (!visited[i]) { List cluster = new ArrayList<>(); dfs(points, visited, cluster, i, threshold); SketchCateModel sketchCateModel = new SketchCateModel(); sketchCateModel.setCustomerList(cluster); sketchCateModel.setId(index); sketchCateModel.setStartPoint(cluster.get(0)); for (JkSketchCustomer c : cluster){ sketchCateModel.setTotalNum(Constants.formatBigdecimal(sketchCateModel.getTotalNum()).add(Constants.formatBigdecimal(c.getTotalNum()))); } sketchCateModel.setTotalCustomer(cluster.size()); clusters.add(sketchCateModel); } } // 打印每个聚类的点 for (int i = 0; i < clusters.size(); i++) { System.out.println("Cluster " + (i + 1) + ": " + clusters.get(i).getStartPoint().getName()+ ": " + clusters.get(i).getCustomerList().size()); } return clusters; } public static double distanceTo(JkSketchCustomer self, JkSketchCustomer other) { List distanceMapParamList =JkSketchServiceImpl.getListFromJsonStr(self.getDistanceJson()); DistanceMapParam param = JkSketchServiceImpl.getParamByCustomerIds( other.getId(),distanceMapParamList); if(param!=null && param.getDistance()!=0){//如果之前已经获取过 return (param.getDistance()); } return DistanceCalculator.calculateDistance(Constants.formatBigdecimal(self.getLatitude()).doubleValue() ,Constants.formatBigdecimal(self.getLongitude()).doubleValue() ,Constants.formatBigdecimal(other.getLatitude()).doubleValue() ,Constants.formatBigdecimal(other.getLongitude()).doubleValue()); } private static void dfs(List points, boolean[] visited, List cluster, int startIndex, double threshold) { visited[startIndex] = true; cluster.add(points.get(startIndex)); JkSketchCustomer startPoint = points.get(startIndex); for (int i = 0; i < points.size(); i++) { if (!visited[i]) { double distance = distanceTo(startPoint,points.get(i)); if (distance <= threshold) { dfs(points, visited, cluster, i, threshold); // 递归添加到聚类中 } } } } /** * 117°40′~118°44′、北纬30°19′~31°34′ * @param args */ public static void main(String[] args) { List points = new ArrayList<>(); for (int i = 0; i <3000; i++) { JkSketchCustomer a = new JkSketchCustomer(); a.setLatitude(new BigDecimal(30.19d+(30.54d-30.19d)*Math.random())); a.setLongitude(new BigDecimal(117.40+(117.74d-117.40d)*Math.random())); a.setName("客户"+i); points.add(a); } double threshold = 1000; // 设置距离阈值,超过这个距离就不属于同一聚类。 clusterPoints(points, threshold); } }