21xrx.com
2024-11-22 01:47:04 Friday
登录
文章检索 我的文章 写文章
KNN算法的Java实现代码
2023-10-02 08:37:25 深夜i     --     --
KNN算法 Java实现 代码

KNN算法是一种常用的机器学习算法,用于分类和回归问题。它基于样本之间的距离来判断新样本的类别,并且可以通过调整K值来控制最终的分类结果。在这篇文章中,我们将介绍KNN算法在Java中的实现代码。

首先,我们需要定义一个KNN类,用于存储训练样本和测试样本的数据。以下是一个简单的KNN类的代码示例:


import java.util.ArrayList;

import java.util.List;

public class KNN {

  private List<DataPoint> trainingSet;

  private List<DataPoint> testSet;

  private int k;

  public KNN(List<DataPoint> trainingSet, List<DataPoint> testSet, int k)

    this.trainingSet = trainingSet;

    this.testSet = testSet;

    this.k = k;

  

  public void classify() {

    for (DataPoint testPoint : testSet) {

      for (DataPoint trainPoint : trainingSet) {

        trainPoint.calculateDistance(testPoint);

      }

      trainingSet.sort((a, b) -> Float.compare(a.getDistance(), b.getDistance()));

      int classA = 0;

      int classB = 0;

      for (int i = 0; i < k; i++) {

        if (trainingSet.get(i).getLabel().equals("A")) {

          classA++;

        } else {

          classB++;

        }

      }

      if (classA > classB) {

        testPoint.setPredictedLabel("A");

      } else {

        testPoint.setPredictedLabel("B");

      }

    }

  }

}

在上面的代码中,我们定义了一个KNN类,它包含了训练样本集合(trainingSet)、测试样本集合(testSet)以及K值。classify()方法用于对测试样本进行分类。

接下来,我们需要定义一个DataPoint类,用于存储样本的特征值和标签,并且计算样本之间的距离。以下是一个简单的DataPoint类的代码示例:


public class DataPoint {

  private float[] features;

  private String label;

  private float distance;

  private String predictedLabel;

  public DataPoint(float[] features, String label)

    this.features = features;

    this.label = label;

  

  public void calculateDistance(DataPoint other) {

    float sum = 0;

    for (int i = 0; i < features.length; i++) {

      sum += Math.pow(features[i] - other.features[i], 2);

    }

    this.distance = (float) Math.sqrt(sum);

  }

  public float getDistance()

    return distance;

  

  public String getLabel()

    return label;

  

  public void setPredictedLabel(String predictedLabel)

    this.predictedLabel = predictedLabel;

  

  public String getPredictedLabel()

    return predictedLabel;

  

}

在上面的代码中,我们定义了一个DataPoint类,它包含了特征值数组(features)、标签(label)、距离(distance)以及预测的标签(predictedLabel)。calculateDistance()方法用于计算样本之间的距离。

最后,我们可以使用以下代码片段来使用KNN类实现一个简单的分类任务:


public class Main {

  public static void main(String[] args) {

    List<DataPoint> trainingSet = new ArrayList<>();

    trainingSet.add(new DataPoint(new float[] 2, "A"));

    trainingSet.add(new DataPoint(new float[] 3, "A"));

    trainingSet.add(new DataPoint(new float[]3, "B"));

    trainingSet.add(new DataPoint(new float[]4, "B"));

    List<DataPoint> testSet = new ArrayList<>();

    testSet.add(new DataPoint(new float[] 2.5f, ""));

    testSet.add(new DataPoint(new float[]3.5f, ""));

    KNN knn = new KNN(trainingSet, testSet, 3);

    knn.classify();

    for (DataPoint testPoint : testSet) {

      System.out.println("Predicted label: " + testPoint.getPredictedLabel());

    }

  }

}

上面的代码中,我们定义了一个训练样本集合(trainingSet)和一个测试样本集合(testSet),并使用KNN类对测试样本进行分类。最后,我们打印出预测的标签。

通过以上代码示例,我们可以看到KNN算法的Java实现代码。当然,这只是一个简单的实现,实际中可能需要考虑更多的因素和优化。希望这篇文章能对理解和使用KNN算法有所帮助。

  
  

评论区

{{item['qq_nickname']}}
()
回复
回复