AIQ株式会社

K-Meansによる色のクラスタリング

はじめに

とある案件で画像の色の抽出を行う必要があったので、その時に利用したK-MeansによるRGBカラーのクラスタリング化についてご紹介します。

K-Meansとは

Wikipedia参照

K-平均法は、一般には以下のような流れで実装される。
データの数を n 、クラスタの数を K としておく。
1.各データ x_i(i=1… n) に対してランダムにクラスタを割り振る。
2.割り振ったデータをもとに各クラスタの中心 V_j(j=1… K) を計算する。計算は通常割り当てられたデータの各要素の算術平均が使用される。
3.各 x_i と各 V_j との距離を求め、x_i を最も近い中心のクラスタに割り当て直す。
4.上記の処理で全ての x_i のクラスタの割り当てが変化しなかった場合、あるいは変化量が事前に設定した一定の閾値を下回った場合に、収束したと判断して処理を終了する。そうでない場合は新しく割り振られたクラスタから V_j を再計算して上記の処理を繰り返す。

難しそうなことが書いていますが、pythonのscikit-learnを利用すると簡単に実装することが出来ます。

実装例(python3でJupyter notebook仕様)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import collections

# データの準備
x = np.array([np.random.normal(loc, 0.1, 100) for loc in np.repeat([1,2], 2)]).flatten()
y = np.array([np.random.normal(loc, 0.1, 100) for loc in np.tile([1,2], 2)]).flatten()

# 4つのクラスタに分割する
pred = KMeans(n_clusters=4).fit(np.c_[x,y])

# 分割後のクラスタのセンターを表示
pred.cluster_centers_

# 各クラスタの要素数
collections.Counter(pred.labels_)

# プロット
plt.scatter(x, y, c = pred.labels_, s = 30)
plt.scatter(pred.cluster_centers_[:,0], pred.cluster_centers_[:,1], c = "r", marker = "+", s = 100)
plt.show()

こんな結果が得られるはずです。
データが4つのグループにキレイに分けられている事が分かります。
kmeans


色のK-Means

色(今回はRGB値)をクラスタリングしたいと思います。
RGBなのでデータを3次元に拡張してあげると実現出来ます。

画像の読み込み

OpenCVで読み込んでRGBを取得します。

import cv2
import numpy as np
import matplotlib.pyplot as plt

img = cv2.imread("Figure_3.png")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(np.array(img))
plt.show()

kmeans4

次元の削減

次元を削減し、RGB値のarrayへ変換します。

flatten = img.reshape(-1,3)

K-Meansでクラスタリング

from sklearn.cluster import KMeans
pred = KMeans(n_clusters=4).fit(flatten)

クラスタごとにデータを分ける

out = zip(pred.labels_, flatten)
clu0 = np.array([data.tolist() for label, data in zip(pred.labels_,flatten) if label==0])
clu1 = np.array([data.tolist() for label, data in zip(pred.labels_,flatten) if label==1])
clu2 = np.array([data.tolist() for label, data in zip(pred.labels_,flatten) if label==2])
clu3 = np.array([data.tolist() for label, data in zip(pred.labels_,flatten) if label==3])

クラスタをクラスタの中心値の色でプロットする

from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D

# センターカラーの取得
color0 = pred.cluster_centers_[0] / 255
color1 = pred.cluster_centers_[1] / 255
color2 = pred.cluster_centers_[2] / 255
color3 = pred.cluster_centers_[3] / 255

# プロット
fig = pyplot.figure()
ax = Axes3D(fig)

# 軸ラベルの設定
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_zlabel("Z-axis")

# 表示範囲の設定
ax.set_xlim(0, 255)
ax.set_ylim(0, 255)
ax.set_zlim(0, 255)

ax.plot(clu0[:,0], clu0[:,1], clu0[:,2], "o", color=color0, ms=4, mew=0.5)
ax.plot(clu1[:,0], clu1[:,1], clu1[:,2], "o", color=color1, ms=4, mew=0.5)
ax.plot(clu2[:,0], clu2[:,1], clu2[:,2], "o", color=color2, ms=4, mew=0.5)
ax.plot(clu3[:,0], clu3[:,1], clu3[:,2], "o", color=color3, ms=4, mew=0.5)

pyplot.show()

kmeans2

1番大きいクラスタの色を確認

collections.Counter(pred.labels_)
pred.cluster_centers_[0]

kmens3

まとめ

色のクラスタリングに関してご紹介しましたが、実際には何個のクラスタに分割するのが最適か決定する必要があります。
その辺の話題に関しては、また別の機会にご紹介します。


さて、AIQでは、私達と東京か札幌で一緒に働ける仲間を募集しています。
詳しくはこちら

私達と一緒にを様々な業界の未来を変えていきませんか?