Flash-KMeans: Fast and Memory-Efficient Exact K-Means

Abstract

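The k-means algorithm has historically been treated mainly as an offline processing primitive. It has been used for dataset construction or embedding preprocessing and has rarely served as a core component of online systems. In this work, we revisit this classic algorithm from the perspective of modern AI system design and recast k-means as an online primitive. We show that existing GPU-based k-means implementations are fundamentally bottlenecked by low-level system constraints rather than by theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck because it explicitly materializes a massive N×K distance matrix in high-bandwidth memory (HBM). At the same time, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregation. To close this performance gap, we present flash-kmeans, an IO-aware and contention-free k-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to bypass intermediate memory materialization entirely; and (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. To ensure practical deployability, we also integrate algorithm-system co-design techniques such as chunked-stream overlap and cache-aware compile heuristics. Extensive evaluations on NVIDIA H200 GPUs show that flash-kmeans achieves up to 17.9× end-to-end speedup over the best baseline, while outperforming the industry-standard libraries cuML and FAISS by 33× and over 200×, respectively.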

One-sentence Summary

Researchers from UC Berkeley, MIT, and UT Austin propose flash-kmeans, an IO-aware GPU implementation that eliminates distance matrix materialization and atomic contention via FlashAssign and sort-inverse update, delivering up to 17.9x speedup for scalable online clustering in modern AI workloads.

Key Contributions

  • Existing GPU implementations of k-means are hindered by severe IO bottlenecks from materializing massive distance matrices and hardware-level atomic contention during centroid updates, preventing their use as efficient online primitives.
  • Flash-KMeans addresses these issues with two core kernel innovations: FlashAssign, which fuses distance computation with online argmin to bypass intermediate memory storage, and sort-inverse update, which transforms high-contention atomic scatters into localized reductions.
  • Evaluations on NVIDIA H200 GPUs show up to a 17.9x end-to-end speedup over best baselines and over 200x improvement compared to FAISS, while enabling seamless out-of-core execution on up to one billion points.

Introduction

K-means clustering is evolving from an offline data processing tool into a critical online primitive for modern AI systems, including vector quantization, sparse routing in large language models, and generative video pipelines. Despite this shift, existing GPU implementations fail to meet latency requirements because they remain bottlenecked by hardware constraints rather than algorithmic complexity. Prior approaches waste memory bandwidth by explicitly materializing massive distance matrices, and they incur hardware-level serialization from atomic write contention during centroid updates. To address these issues, the authors introduce Flash-KMeans, an exact and IO-aware implementation that fuses distance computation with online argmin to bypass intermediate memory storage and replaces irregular atomic scatters with a sort-inverse update strategy for efficient aggregation. This system-level redesign eliminates key bottlenecks, delivering up to 17.9× end-to-end speedup over baselines while enabling scalable execution on datasets exceeding one billion points.

Method

The authors introduce flash-kmeans, a highly optimized implementation designed to overcome the severe memory and synchronization bottlenecks inherent in standard GPU-based k-means clustering. The methodology focuses on restructuring the execution dataflow to eliminate IO overheads and resolve write-side contention without altering the underlying mathematical objective.
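As a reference point for "the underlying mathematical objective", the NumPy sketch below performs one standard Lloyd iteration the straightforward way. It materializes the full $N \times K$ distance matrix and accumulates centroid sums with a scatter-add. This is not flash-kmeans code; it only illustrates the exact result an optimized implementation must reproduce. The function name `lloyd_step_naive` is hypothetical. Keeping the previous centroid for an empty cluster is also an assumption, since the summary does not say how empty clusters are handled.

```python
import numpy as np

def lloyd_step_naive(x: np.ndarray, c: np.ndarray):
    """One exact k-means (Lloyd) iteration, written the straightforward way.

    x: (N, d) points, c: (K, d) centroids. Hypothetical reference helper.
    """
    # Assignment: build the full N x K matrix of squared distances.
    dist = ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1)  # (N, K)
    a = dist.argmin(axis=1)  # (N,) cluster id per point

    # Update: scatter-style accumulation. np.add.at is an unbuffered
    # scatter-add, the CPU analogue of per-token atomicAdd on a GPU.
    k, d = c.shape
    sums = np.zeros((k, d), dtype=x.dtype)
    counts = np.zeros(k, dtype=np.int64)
    np.add.at(sums, a, x)
    np.add.at(counts, a, 1)

    # Assumption: empty clusters keep their previous centroid.
    new_c = c.copy()
    nonempty = counts > 0
    new_c[nonempty] = sums[nonempty] / counts[nonempty, None]
    return a, new_c
```

Here `dist` is the $N \times K$ intermediate that FlashAssign avoids, and `np.add.at` plays the role of the per-token atomic scatter that sort-inverse update replaces.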

FlashAssign: Materialization-Free Assignment

To address the memory wall caused by materializing the massive distance matrix $D \in \mathbb{R}^{N \times K}$, the authors propose FlashAssign. This module fuses the distance computation and row-wise reduction into a single streaming procedure. Instead of writing the full distance matrix to High Bandwidth Memory (HBM) and reading it back, FlashAssign maintains running states for the minimum distance and corresponding centroid index directly in registers.
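A distance tile can be produced on-chip with little more than a small matrix multiply if the squared Euclidean distance is expanded as shown below. This common formulation is assumed here for illustration; the summary does not state which distance formula flash-kmeans evaluates. Because $\|x_i\|_2^2$ is identical for every centroid $c_k$, it can be dropped from the argmin that produces the assignment $a_i$:

```latex
\|x_i - c_k\|_2^2 = \|x_i\|_2^2 - 2\, x_i^\top c_k + \|c_k\|_2^2,
\qquad
a_i = \arg\min_{k \in \{1,\dots,K\}} \|x_i - c_k\|_2^2
    = \arg\min_{k \in \{1,\dots,K\}} \left( \|c_k\|_2^2 - 2\, x_i^\top c_k \right).
```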

The process utilizes an online argmin update. For each data point, the kernel scans centroids in tiles. It computes local distances on-chip, identifies the local minimum within the tile, and compares it with the running minimum to update the global assignment. This approach ensures that the $N \times K$ distance matrix is never explicitly constructed in memory.
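The online argmin can be sketched in NumPy as follows. For brevity, the sketch vectorizes over all points and tiles only over centroids, so it holds an $N \times$ `tile_k` tile rather than per-point register state. The function name and the `tile_k` parameter are hypothetical and do not describe flash-kmeans' API.

```python
import numpy as np

def flash_assign_sketch(x: np.ndarray, c: np.ndarray, tile_k: int = 64):
    """Nearest-centroid assignment without building the N x K matrix.

    Hypothetical sketch: scans centroids in tiles of tile_k and keeps a
    running (min distance, argmin) pair per point. Only an (N, tile_k)
    block of distances exists at any time.
    """
    n = x.shape[0]
    best_dist = np.full(n, np.inf, dtype=np.float64)
    best_idx = np.zeros(n, dtype=np.int64)
    for start in range(0, c.shape[0], tile_k):
        tile = c[start:start + tile_k]                                # (t, d)
        d_tile = ((x[:, None, :] - tile[None, :, :]) ** 2).sum(-1)   # (N, t)
        local_idx = d_tile.argmin(axis=1)          # local winner in this tile
        local_min = d_tile[np.arange(n), local_idx]
        # Online update: adopt the tile's winner only if it beats the running min.
        better = local_min < best_dist
        best_dist[better] = local_min[better]
        best_idx[better] = start + local_idx[better]
    return best_idx, best_dist
```

Using a strict `<` when merging tiles keeps the lowest centroid index on ties, so the result matches a plain `argmin` over the full matrix.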

As illustrated in the framework diagram above, the algorithm loops over centroid tiles. For each point $X_i$, it computes distances against a centroid block $C_j$, finds the local minimum, and updates the global minimum index. By employing two-dimensional tiling and asynchronous prefetching, the kernel hides memory latency and reduces the IO complexity from $O(NK)$ to $O(Nd + Kd)$, where $d$ is the feature dimension.
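To make the gap concrete, here is a back-of-the-envelope model of HBM traffic for the assignment step. It is an illustrative assumption, not a measurement. In the materialized path, the $N \times K$ matrix is written and then read back once for the argmin. In the fused path, only the inputs are read. The model ignores repeated reads of centroid tiles and the $N$-element output.

```python
def assignment_io_bytes(n: int, k: int, d: int, bytes_per_elem: int = 4):
    """Rough HBM bytes moved by the assignment step (illustrative model only).

    materialized: read X and C, write the N x K distance matrix, read it back.
    fused:        read X and C once; distances stay on chip.
    Ignores re-reads of centroid tiles and the N-element assignment output.
    """
    inputs = (n * d + k * d) * bytes_per_elem
    dist_matrix = n * k * bytes_per_elem
    materialized = inputs + 2 * dist_matrix
    fused = inputs
    return materialized, fused
```

For $N = 10^6$ points, $K = 1024$ centroids, $d = 128$, and fp32 values, `assignment_io_bytes(1_000_000, 1024, 128)` gives 8,704,524,288 bytes (about 8.7 GB) for the materialized path versus 512,524,288 bytes (about 0.51 GB) for the fused path. These sizes are arbitrary examples, not the paper's benchmark settings.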

Sort-Inverse Update: Low-Contention Aggregation

In the centroid update stage, standard implementations suffer from severe atomic write contention because multiple threads frequently attempt to update the same centroid simultaneously using scatter-style atomic additions. To resolve this, the authors propose the sort-inverse update strategy.

The core idea is to transform the token-to-cluster update into a cluster-to-token gather operation. The system first applies an argsort operation to the assignment vector $a$ to obtain a permutation index. This reorders the tokens such that identical cluster IDs are grouped into contiguous segments.
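The sort-and-segment idea can be illustrated on the CPU with NumPy. A stable `argsort` of the assignment vector groups each cluster's tokens into one contiguous run, rows are gathered in that order, and each run is summed with `np.add.reduceat`. This hypothetical sketch shows the dataflow only and is not the CUDA kernel; the function name is illustrative.

```python
import numpy as np

def sort_inverse_update_sketch(x: np.ndarray, a: np.ndarray, k: int):
    """Per-cluster sums and counts via sort + contiguous segment reduction.

    Hypothetical sketch. Assumes 0 <= a[i] < k. Instead of scattering each
    token into sums[a[i]], tokens are sorted by cluster id so that each
    cluster's tokens form one contiguous segment, which is then reduced.
    """
    perm = np.argsort(a, kind="stable")   # permutation grouping equal ids
    a_sorted = a[perm]
    x_sorted = x[perm]                    # gather rows from the original matrix
    ids = np.arange(k)
    starts = np.searchsorted(a_sorted, ids, side="left")   # segment begins
    ends = np.searchsorted(a_sorted, ids, side="right")    # segment ends
    counts = ends - starts
    sums = np.zeros((k, x.shape[1]), dtype=x.dtype)
    nonempty = counts > 0
    # reduceat sums x_sorted[s_j : s_{j+1}] for consecutive nonempty starts;
    # empty clusters own no rows, so each slice is exactly one cluster.
    sums[nonempty] = np.add.reduceat(x_sorted, starts[nonempty], axis=0)
    return sums, counts
```

Because every segment is contiguous, each output row is updated once per segment instead of once per token.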

The figure below contrasts the standard scatter-style update with the proposed sort-inverse approach. In the standard method (a), tokens are scattered irregularly, causing conflicts across multiple blocks. In the sort-inverse method (b), tokens are sorted by cluster ID, creating contiguous segments. This allows each Cooperative Thread Array (CTA) to process a chunk of the sorted sequence, gathering features from the original matrix and accumulating partial sums entirely in fast on-chip memory. Global atomic operations are only issued at segment boundaries.
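The per-CTA behavior can be emulated in plain Python. The sorted sequence is split into chunks of `block_n` tokens, each cluster's rows are reduced locally inside a chunk, and one partial sum per (chunk, segment) pair is flushed to the global result. In this sketch a "flush" stands in for $d$ global atomic adds; `block_n` and the function name are hypothetical.

```python
import math
import numpy as np

def chunked_segment_flush_sketch(x_sorted: np.ndarray, a_sorted: np.ndarray,
                                 k: int, block_n: int):
    """Emulate CTAs that each own block_n consecutive sorted tokens.

    Hypothetical sketch: each chunk sums its rows per cluster id locally
    (stand-in for on-chip accumulation) and issues one flush per
    (chunk, segment) pair. Returns the global sums and the flush count.
    """
    n, d = x_sorted.shape
    sums = np.zeros((k, d), dtype=x_sorted.dtype)
    flushes = 0
    for lo in range(0, n, block_n):
        hi = min(lo + block_n, n)
        ids = a_sorted[lo:hi]
        # Positions inside this chunk where the cluster id changes.
        cuts = np.flatnonzero(np.diff(ids)) + 1
        seg_starts = np.concatenate(([0], cuts))
        partial = np.add.reduceat(x_sorted[lo:hi], seg_starts, axis=0)
        for cid, row in zip(ids[seg_starts], partial):
            sums[cid] += row   # stands in for d global atomic adds
            flushes += 1
    return sums, flushes
```

Since the sequence is sorted, each chunk boundary can split at most one segment, so the number of flushes is at most $K + \lceil N/B_N \rceil$ (with $B_N$ = `block_n`), each covering $d$ features.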

This reorganization reduces the number of atomic operations from $O(Nd)$ to $O((K + \lceil N/B_N \rceil)d)$, where $B_N$ is the number of sorted tokens each CTA processes. As shown in the execution timeline (c), this eliminates the frequent stalls caused by atomic contention, enabling contention-free memory writes and significantly accelerating the reduction phase.
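Plugging example sizes into the two bounds shows the scale of the reduction. The values used below ($N = 10^6$, $K = 1024$, $d = 128$, $B_N = 256$) are hypothetical, and the results are upper-bound operation counts, not timings.

```python
import math

def atomic_adds(n: int, k: int, d: int, block_n: int):
    """Upper-bound atomic add counts for the centroid update (illustrative)."""
    scatter = n * d                                   # one atomic per token per feature
    sort_inverse = (k + math.ceil(n / block_n)) * d   # one flush per segment piece
    return scatter, sort_inverse
```

For these sizes, `atomic_adds(1_000_000, 1024, 128, 256)` returns 128,000,000 for the scatter path and 631,168 for the sort-inverse bound, roughly 200× fewer atomic adds.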

Algorithm-System Co-design

To ensure deployability in real systems, flash-kmeans incorporates several system-level optimizations. For large-scale data that exceeds GPU memory, the authors implement a chunked stream overlap design. This partitions data into chunks and uses CUDA streams to coordinate asynchronous host-to-device transfers with computation, following a double-buffer streaming pattern. Additionally, a cache-aware compile heuristic is employed to select high-quality kernel configurations based on hardware characteristics and problem shape, minimizing the time-to-first-run overhead typically associated with exhaustive tuning.
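The double-buffer pattern can be emulated with a background thread standing in for a copy stream. While chunk $i$ is processed, chunk $i+1$ is already being copied. Per-cluster sums and counts accumulate across chunks, so only about two chunks are resident at a time. This sketch rests on several assumptions and is not flash-kmeans code. `ThreadPoolExecutor` replaces CUDA streams, a NumPy copy replaces the host-to-device transfer, and the per-chunk assignment uses a simple materialized distance tile for brevity. The function name is hypothetical.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def out_of_core_update_sketch(host_x: np.ndarray, c: np.ndarray, chunk: int):
    """One out-of-core k-means iteration with a double-buffered chunk pipeline.

    Hypothetical sketch: a worker thread "transfers" the next chunk while the
    current one is processed (stand-in for async H2D copies on a second CUDA
    stream). Working memory scales with chunk, not with len(host_x).
    """
    k, d = c.shape
    sums = np.zeros((k, d))
    counts = np.zeros(k, dtype=np.int64)
    bounds = [(lo, min(lo + chunk, len(host_x)))
              for lo in range(0, len(host_x), chunk)]

    def transfer(lo, hi):
        return np.array(host_x[lo:hi], copy=True)  # stand-in for an H2D copy

    with ThreadPoolExecutor(max_workers=1) as copier:
        pending = copier.submit(transfer, *bounds[0]) if bounds else None
        for i in range(len(bounds)):
            xb = pending.result()                  # wait for this chunk's copy
            if i + 1 < len(bounds):                # start copying the next chunk
                pending = copier.submit(transfer, *bounds[i + 1])
            dist = ((xb[:, None, :] - c[None, :, :]) ** 2).sum(-1)
            a = dist.argmin(axis=1)
            np.add.at(sums, a, xb)                 # accumulate across chunks
            np.add.at(counts, a, 1)

    new_c = c.copy()                               # assumption: empty clusters keep old centroid
    nz = counts > 0
    new_c[nz] = sums[nz] / counts[nz, None]
    return new_c
```

Because sums and counts are additive across chunks, the result equals a single in-memory Lloyd iteration; only the peak working set changes.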

Experiment

  • Efficiency evaluations demonstrate that flash-kmeans consistently outperforms optimized baselines across diverse workloads, achieving up to 17.9× speedup in compute-intensive scenarios and 15.3× in highly batched settings while preventing out-of-memory failures in memory-intensive regimes.
  • Kernel-level analysis confirms that custom FlashAssign and Sort-Inverse Update modules effectively eliminate distance matrix materialization and atomic contention bottlenecks, delivering up to 21.2× and 6.3× speedups respectively.
  • Large-scale out-of-core experiments validate that the system successfully processes datasets up to one billion points by bounding peak memory usage, resulting in 6.3× to 10.5× faster iteration times compared to the most robust existing baseline.
  • Algorithm-system co-design tests show that a cache-aware compile heuristic reduces configuration search time by up to 175× compared to exhaustive tuning while maintaining near-optimal runtime performance with negligible degradation.
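The summary does not describe how the cache-aware compile heuristic chooses configurations. One plausible shape for such a heuristic, assumed here purely for illustration, works in three steps. It filters candidate tile sizes by an on-chip memory budget, picks the largest tile that suits the problem, and memoizes the decision per device and shape bucket so that repeated or similar shapes skip the search. The candidate list, the `pick_config` and `shape_bucket` names, and the footprint model are all hypothetical.

```python
import functools
import math

# Hypothetical candidate (block_n, block_k) tile configs; not flash-kmeans' real search space.
CANDIDATES = [(64, 64), (128, 64), (128, 128), (256, 128)]

@functools.lru_cache(maxsize=None)
def pick_config(device: str, n_bucket: int, k_bucket: int, d: int,
                smem_bytes: int, bytes_per_elem: int = 2):
    """Hypothetical heuristic: choose a tile config, memoized per device/shape.

    Keeps configs whose point tile + centroid tile fit in the on-chip budget,
    then prefers the largest tile not exceeding the problem size. `device`
    is used only as part of the cache key.
    """
    def footprint(bn, bk):
        return (bn + bk) * d * bytes_per_elem

    fits = [cfg for cfg in CANDIDATES if footprint(*cfg) <= smem_bytes]
    if not fits:
        return min(CANDIDATES, key=lambda cfg: footprint(*cfg))
    useful = [cfg for cfg in fits
              if cfg[0] <= n_bucket and cfg[1] <= k_bucket] or fits[:1]
    return max(useful, key=lambda cfg: cfg[0] * cfg[1])

def shape_bucket(x: int) -> int:
    """Round up to a power of two so nearby shapes share one cache entry."""
    return 1 << max(0, math.ceil(math.log2(max(1, x))))
```

After the first call for a given device and bucket, later calls with nearby shapes are served from the `lru_cache` without re-running the selection.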
