26 lines
1.1 KiB
Diff
26 lines
1.1 KiB
Diff
From e5850e2df3918b17d029f796e71825bec0dcf0ed Mon Sep 17 00:00:00 2001
|
|
From: huangbin <huangbin58@huawei.com>
|
|
Date: Thu, 21 Nov 2024 18:29:56 +0800
|
|
Subject: [PATCH] fix bug: wrong label for dbscan (exclude label -1 when picking the majority cluster)
|
|
|
|
---
|
|
.../slow_node_algs/space_comp_detector/sliding_window_dbscan.py | 2 ++
|
|
1 file changed, 2 insertions(+)
|
|
|
|
diff --git a/anteater/model/algorithms/slow_node_algs/space_comp_detector/sliding_window_dbscan.py b/anteater/model/algorithms/slow_node_algs/space_comp_detector/sliding_window_dbscan.py
|
|
index cccae33..0bace48 100644
|
|
--- a/anteater/model/algorithms/slow_node_algs/space_comp_detector/sliding_window_dbscan.py
|
|
+++ b/anteater/model/algorithms/slow_node_algs/space_comp_detector/sliding_window_dbscan.py
|
|
@@ -109,6 +109,8 @@ class SlidingWindowDBSCAN():
|
|
labels = dbscan.fit_predict(sim_scores)
|
|
logger.info(f"dnscan labels: {labels}")
|
|
label_counts = Counter(labels)
|
|
+ if -1 in label_counts:
|
|
+ label_counts.pop(-1)
|
|
# 找到样本数量最多的类别
|
|
most_common_label, _ = label_counts.most_common(1)[0]
|
|
new_labels = np.where(labels == most_common_label, 0, 1)
|
|
--
|
|
2.43.0
|
|
|