调教AI给我写了一个KD树的算法

调教AI给我写了一个KD树的算法

码农世界 2024-05-23 前端 61 次浏览 0个评论

我不擅长C++,但是目前需要用C++写一个KD树的算法。首先我有一份点云数据,需要找给定坐标范围0.1mm内的所有点。

于是我开始问AI,他一开始给的答案,完全是错误的,但是我一步步给出反馈,告诉他的问题,他居然给出的正确的答案(我和python代码做了对比,python现成程的库实现,可作为结果进行参考)

下面,是我和AI的对话过程。

这看起来像是一系列三维空间中的点的坐标数据,每行表示一个点,包含了该点的 x、y、z 坐标值。如果你想要使用 KD 树对这些点进行空间检索或者其他操作,你可以使用 C++ 来读取这些数据,然后构建 KD 树。以下是一个简单的示例代码,可以帮助你读取这个文件并构建 KD 树:

#include 
#include 
#include 
#include 
#include 
// Define a structure to represent a point in 3D space
struct Point3D {
    double x, y, z;
};
// Define a node structure for the KD tree
struct KDNode {
    Point3D point;
    KDNode* left;
    KDNode* right;
    int axis; // 0 for x, 1 for y, 2 for z
    KDNode(const Point3D& p) : point(p), left(nullptr), right(nullptr), axis(0) {}
};
// Function to recursively build a KD tree from a list of points
KDNode* buildKDTree(std::vector& points, int depth = 0) {
    if (points.empty()) return nullptr;
    // Determine current axis
    int axis = depth % 3;
    // Sort points based on current axis
    if (axis == 0)
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.x < b.x; });
    else if (axis == 1)
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.y < b.y; });
    else
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.z < b.z; });
    // Find median
    size_t medianIdx = points.size() / 2;
    KDNode* node = new KDNode(points[medianIdx]);
    node->axis = axis;
    // Recursively build left and right subtrees
    std::vector leftPoints(points.begin(), points.begin() + medianIdx);
    std::vector rightPoints(points.begin() + medianIdx + 1, points.end());
    node->left = buildKDTree(leftPoints, depth + 1);
    node->right = buildKDTree(rightPoints, depth + 1);
    return node;
}
// Function to print a KD tree (for debugging)
void printKDTree(KDNode* root) {
    if (root) {
        std::cout << "(" << root->point.x << ", " << root->point.y << ", " << root->point.z << ")" << std::endl;
        printKDTree(root->left);
        printKDTree(root->right);
    }
}
int main() {
    std::ifstream file("point_cloud.txt");
    if (!file.is_open()) {
        std::cerr << "Error opening file." << std::endl;
        return 1;
    }
    std::vector points;
    std::string line;
    while (std::getline(file, line)) {
        std::istringstream iss(line);
        double x, y, z;
        char comma;
        if (iss >> x >> comma >> y >> comma >> z) {
            points.push_back({x, y, z});
        } else {
            std::cerr << "Error reading line: " << line << std::endl;
        }
    }
    file.close();
    // Build KD tree
    KDNode* root = buildKDTree(points);
    // Print KD tree (for debugging)
    printKDTree(root);
    // Cleanup
    // You need to implement a function to delete the KD tree nodes
    return 0;
}

请注意,这只是一个简单的示例,仅用于演示如何读取点云文件并构建 KD 树。在实际应用中,你可能需要添加更多的功能和错误检查,并在程序结束时正确地清理分配的内存。

要解决这个问题,我们可以使用 KD 树进行最近邻搜索。给定一个点,我们可以通过搜索 KD 树来找到距离该点最近的点,并检查这些点是否在给定的范围内。下面是一个修改后的示例代码,用于找到给定点附近一定范围内的所有点:

#include 
#include 
#include 
#include 
#include 
#include 
// Define a structure to represent a point in 3D space
struct Point3D {
    double x, y, z;
};
// Define a node structure for the KD tree
struct KDNode {
    Point3D point;
    KDNode* left;
    KDNode* right;
    int axis; // 0 for x, 1 for y, 2 for z
    KDNode(const Point3D& p) : point(p), left(nullptr), right(nullptr), axis(0) {}
};
// Function to recursively build a KD tree from a list of points
KDNode* buildKDTree(std::vector& points, int depth = 0) {
    if (points.empty()) return nullptr;
    // Determine current axis
    int axis = depth % 3;
    // Sort points based on current axis
    if (axis == 0)
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.x < b.x; });
    else if (axis == 1)
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.y < b.y; });
    else
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.z < b.z; });
    // Find median
    size_t medianIdx = points.size() / 2;
    KDNode* node = new KDNode(points[medianIdx]);
    node->axis = axis;
    // Recursively build left and right subtrees
    std::vector leftPoints(points.begin(), points.begin() + medianIdx);
    std::vector rightPoints(points.begin() + medianIdx + 1, points.end());
    node->left = buildKDTree(leftPoints, depth + 1);
    node->right = buildKDTree(rightPoints, depth + 1);
    return node;
}
// Function to calculate Euclidean distance between two points
double distance(const Point3D& a, const Point3D& b) {
    return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) + std::pow(a.z - b.z, 2));
}
// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist <= radius) {
        result.push_back(root->point);
    }
    // Recursively search left and right subtrees if they might contain points within the radius
    if (root->left && target.axis[root->axis] - radius < root->point.axis[root->axis]) {
        rangeSearch(root->left, target, radius, result);
    }
    if (root->right && target.axis[root->axis] + radius >= root->point.axis[root->axis]) {
        rangeSearch(root->right, target, radius, result);
    }
}
int main() {
    std::ifstream file("point_cloud.txt");
    if (!file.is_open()) {
        std::cerr << "Error opening file." << std::endl;
        return 1;
    }
    std::vector points;
    std::string line;
    while (std::getline(file, line)) {
        std::istringstream iss(line);
        double x, y, z;
        char comma;
        if (iss >> x >> comma >> y >> comma >> z) {
            points.push_back({x, y, z});
        } else {
            std::cerr << "Error reading line: " << line << std::endl;
        }
    }
    file.close();
    // Build KD tree
    KDNode* root = buildKDTree(points);
    // Define the target point and search radius
    Point3D target = {1.0, 2.0, 3.0};
    double radius = 0.1; // in this case, 0.1mm
    // Perform range search
    std::vector result;
    rangeSearch(root, target, radius, result);
    // Print the points within the range
    std::cout << "Points within " << radius << "mm of (" << target.x << ", " << target.y << ", " << target.z << "):" << std::endl;
    for (const auto& p : result) {
        std::cout << "(" << p.x << ", " << p.y << ", " << p.z << ")" << std::endl;
    }
    // Cleanup
    // You need to implement a function to delete the KD tree nodes
    return 0;
}

这段代码将找到距离给定点不超过 0.1mm 的所有点,并将它们打印出来。请确保在使用时,将实际的搜索半径(以毫米为单位)传递给 `rangeSearch` 函数。

你是对的,我的确犯了一个错误。在 `rangeSearch` 函数中,我没有正确地限制搜索的范围,导致了重复添加点的问题。让我来修复这个问题。

修改后的 `rangeSearch` 函数应该会在递归搜索左右子树时正确地限制搜索范围,避免了重复添加点的问题。以下是修复后的代码:

// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist <= radius) {
        result.push_back(root->point);
    }
    // Recursively search left and right subtrees if they might contain points within the radius
    if (root->left && (target.x - radius) <= root->point.x) {
        rangeSearch(root->left, target, radius, result);
    }
    if (root->right && (target.x + radius) >= root->point.x) {
        rangeSearch(root->right, target, radius, result);
    }
    if (root->left && (target.y - radius) <= root->point.y) {
        rangeSearch(root->left, target, radius, result);
    }
    if (root->right && (target.y + radius) >= root->point.y) {
        rangeSearch(root->right, target, radius, result);
    }
    if (root->left && (target.z - radius) <= root->point.z) {
        rangeSearch(root->left, target, radius, result);
    }
    if (root->right && (target.z + radius) >= root->point.z) {
        rangeSearch(root->right, target, radius, result);
    }
}

这次,我修正了对左右子树的递归调用,确保了在向下搜索时正确地限制了搜索范围。现在应该能够正确地找到给定点附近半径为0.1mm范围内的所有点。

我明白了问题所在。在 `rangeSearch` 函数中,我没有正确地检查当前节点的坐标与目标点的距离,以确定是否需要继续向下搜索。让我来修正这个问题,确保每个节点只在满足条件时才向下搜索。

// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist <= radius) {
        result.push_back(root->point);
    }
    // Determine current axis
    int axis = root->axis;
    // Recursively search left and right subtrees if they might contain points within the radius
    if (axis == 0) {
        if (root->left && (target.x - radius) <= root->point.x) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.x + radius) >= root->point.x) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else if (axis == 1) {
        if (root->left && (target.y - radius) <= root->point.y) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.y + radius) >= root->point.y) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else {
        if (root->left && (target.z - radius) <= root->point.z) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.z + radius) >= root->point.z) {
            rangeSearch(root->right, target, radius, result);
        }
    }
}

在这个修正版本中,我添加了对当前节点的坐标与目标点的距离的检查,以确定是否需要继续向下搜索。这样就可以确保每个节点只在满足条件时才向下搜索,从而避免了重复添加点的问题。现在应该能够正确地找到给定点附近半径为0.1mm范围内的所有点。

抱歉,打印结果的小数位数不够精确是因为我没有设置输出精度。让我修正这个问题,确保打印结果的小数位数和原始数据一致。

#include  // Add this include for setprecision
// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist <= radius) {
        result.push_back(root->point);
    }
    // Determine current axis
    int axis = root->axis;
    // Recursively search left and right subtrees if they might contain points within the radius
    if (axis == 0) {
        if (root->left && (target.x - radius) <= root->point.x) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.x + radius) >= root->point.x) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else if (axis == 1) {
        if (root->left && (target.y - radius) <= root->point.y) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.y + radius) >= root->point.y) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else {
        if (root->left && (target.z - radius) <= root->point.z) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.z + radius) >= root->point.z) {
            rangeSearch(root->right, target, radius, result);
        }
    }
}
int main() {
    // Build KD tree and other code...
    // Print the points within the range
    std::cout << std::fixed << std::setprecision(6); // Set precision to 6 decimal places
    std::cout << "Points within " << radius << "mm of (" << target.x << ", " << target.y << ", " << target.z << "):" << std::endl;
    for (const auto& p : result) {
        std::cout << "(" << p.x << ", " << p.y << ", " << p.z << ")" << std::endl;
    }
    // Cleanup and other code...
}

通过使用 `std::fixed` 和 `std::setprecision(6)`,我设置了输出的精度为 6 位小数,确保了打印结果的小数位数与原始数据一致。

python的实现(作为验证AI的答案是否正确)

from scipy.spatial import KDTree
import time
def read_points_from_file(filename):
    points = []
    with open(filename, 'r') as file:
        for line in file:
            # 解析每行数据,并转换为浮点数
            x, y, z = map(float, line.strip().split(','))
            points.append((x, y, z))
    return points
def find_points_in_range(kd_tree, target_point, radius):
    # 使用 KD 树的 query_ball_point 方法搜索指定范围内的所有点
    indices = kd_tree.query_ball_point(target_point, radius)
    return [kd_tree.data[i] for i in indices]
def main():
    start_time = time.time()
    #filename = "D:/test_pro/old/CppTest/算法测试/test0/x64/Debug/test.txt";
    filename = "D:/test_pro/old/CppTest/算法测试/test0/x64/Debug/20240502.txt";
    points = read_points_from_file(filename)
    end_time = time.time()
    duration = (end_time - start_time) * 1000  # 将秒转换为毫秒
    print("文件读取时间:", duration, "milliseconds")
    start_time = time.time()
    kd_tree = KDTree(points)
    end_time = time.time()
    duration = (end_time - start_time) * 1000  # 将秒转换为毫秒
    print("构建kd树时间:", duration, "milliseconds")
    target_point = ( -141.874398,25.022764,2.484327)
    radius = 1
    start_time = time.time()
    points_in_range = find_points_in_range(kd_tree, target_point, radius)
    end_time = time.time()
    duration = (end_time - start_time) * 1000  # 将秒转换为毫秒
    print("查找时间:", duration, "milliseconds")
    print("一共有:", len(points_in_range), "个点")
    print("在指定范围内的点有:", points_in_range)
    for item in points_in_range:
         print(f'{item[0]},{item[1]},{item[2]}')
if __name__ == "__main__":
    main()

小结

 让我比较震惊的一点是,它无法直接给出正确的答案,但是如果能正确的指出它的错误,它居然就能一次次逼近正确,最终给出正确的答案。

转载请注明来自码农世界,本文标题:《调教AI给我写了一个KD树的算法》

百度分享代码,如果开启HTTPS请参考李洋个人博客
每一天,每一秒,你所做的决定都会改变你的人生!

发表评论

快捷回复:

评论列表 (暂无评论,61人围观)参与讨论

还没有评论,来说两句吧...

Top