21xrx.com
2025-04-15 04:20:27 Tuesday
文章检索 我的文章 写文章
KNN算法的Java实现代码
2023-10-02 08:37:25 深夜i     14     0
KNN算法 Java实现 代码

KNN算法是一种常用的机器学习算法,用于分类和回归问题。它基于样本之间的距离来判断新样本的类别,并且可以通过调整K值来控制最终的分类结果。在这篇文章中,我们将介绍KNN算法在Java中的实现代码。

首先,我们需要定义一个KNN类,用于存储训练样本和测试样本的数据。以下是一个简单的KNN类的代码示例:

import java.util.ArrayList;
import java.util.List;
public class KNN {
  private List<DataPoint> trainingSet;
  private List<DataPoint> testSet;
  private int k;
  public KNN(List<DataPoint> trainingSet, List<DataPoint> testSet, int k)
    this.trainingSet = trainingSet;
    this.testSet = testSet;
    this.k = k;
  
  public void classify() {
    for (DataPoint testPoint : testSet) {
      for (DataPoint trainPoint : trainingSet) {
        trainPoint.calculateDistance(testPoint);
      }
      trainingSet.sort((a, b) -> Float.compare(a.getDistance(), b.getDistance()));
      int classA = 0;
      int classB = 0;
      for (int i = 0; i < k; i++) {
        if (trainingSet.get(i).getLabel().equals("A")) {
          classA++;
        } else {
          classB++;
        }
      }
      if (classA > classB) {
        testPoint.setPredictedLabel("A");
      } else {
        testPoint.setPredictedLabel("B");
      }
    }
  }
}

在上面的代码中,我们定义了一个KNN类,它包含了训练样本集合(trainingSet)、测试样本集合(testSet)以及K值。classify()方法用于对测试样本进行分类。

接下来,我们需要定义一个DataPoint类,用于存储样本的特征值和标签,并且计算样本之间的距离。以下是一个简单的DataPoint类的代码示例:

public class DataPoint {
  private float[] features;
  private String label;
  private float distance;
  private String predictedLabel;
  public DataPoint(float[] features, String label)
    this.features = features;
    this.label = label;
  
  public void calculateDistance(DataPoint other) {
    float sum = 0;
    for (int i = 0; i < features.length; i++) {
      sum += Math.pow(features[i] - other.features[i], 2);
    }
    this.distance = (float) Math.sqrt(sum);
  }
  public float getDistance()
    return distance;
  
  public String getLabel()
    return label;
  
  public void setPredictedLabel(String predictedLabel)
    this.predictedLabel = predictedLabel;
  
  public String getPredictedLabel()
    return predictedLabel;
  
}

在上面的代码中,我们定义了一个DataPoint类,它包含了特征值数组(features)、标签(label)、距离(distance)以及预测的标签(predictedLabel)。calculateDistance()方法用于计算样本之间的距离。

最后,我们可以使用以下代码片段来使用KNN类实现一个简单的分类任务:

public class Main {
  public static void main(String[] args) {
    List<DataPoint> trainingSet = new ArrayList<>();
    trainingSet.add(new DataPoint(new float[] 2, "A"));
    trainingSet.add(new DataPoint(new float[] 3, "A"));
    trainingSet.add(new DataPoint(new float[]3, "B"));
    trainingSet.add(new DataPoint(new float[]4, "B"));
    List<DataPoint> testSet = new ArrayList<>();
    testSet.add(new DataPoint(new float[] 2.5f, ""));
    testSet.add(new DataPoint(new float[]3.5f, ""));
    KNN knn = new KNN(trainingSet, testSet, 3);
    knn.classify();
    for (DataPoint testPoint : testSet) {
      System.out.println("Predicted label: " + testPoint.getPredictedLabel());
    }
  }
}

上面的代码中,我们定义了一个训练样本集合(trainingSet)和一个测试样本集合(testSet),并使用KNN类对测试样本进行分类。最后,我们打印出预测的标签。

通过以上代码示例,我们可以看到KNN算法的Java实现代码。当然,这只是一个简单的实现,实际中可能需要考虑更多的因素和优化。希望这篇文章能对理解和使用KNN算法有所帮助。

  
  

评论区