手書き数字を k-means でクラスタリングしてみた
機械学習で良く使われる手書き数字のデータセット MNIST を使って、k-means 法でどれぐらい教師無しのクラスタリングできるか試してみました (Matlab)。結論から言うと、精度はぜんぜん高くないのですが (正答率 6 割ぐらい)、メモとして残しておきます (構造の違うデータだったらもっとよい精度が出るかも…汗)。
MNIST データセットには、28×28 ピクセルの手書き画像が、訓練用 6 万枚、テスト用 1 万枚、あとこれらの正解ラベルが入っています。
www.atmarkit.co.jp
Matlab のデータファイル形式 (.mat) に成型されたものが下にあったので、それを拝借 (mnist.mat)。
github.com
手順としては、手書き数字データを読み込んで、(そのままだと 28×28 = 784 次元あって計算量が大きいので) 主成分分析 PCA で次元を減らして、k-means でクラスタリング、最後に損失 (= 1-正解率) を計算、という流れです。
まず、MNIST データを読み込んで、uint8 (だと Matlab がエラーを吐くので) を double 型に変換。
load('mnist.mat') trainX = double(trainX); trainY = double(trainY); testX = double(testX); testY = double(testY);
次に、PCA で次元を減らします。寄与率 (explained variance) の累計が 80% になる次元まで使うことにすると、44 次元となりました (784 → 44 に次元削減)。
[~,newX,~,~,explained] = pca(trainX); % explained variance th = 80; cumExplained = cumsum(explained); ndim = find(cumExplained>th,1,'first'); % plot explained variance figure(1); clf plot(cumExplained); refline(0,th) box off xlabel('Number of PCs') ylabel('Cumulative contribution') title(sprintf('The first %u dimensions -> k-means',ndim))
上の 44 次元のデータを k-means でクラスタリングします。k-means はクラスタ数を指定する必要がありますが、0-9 の数字ということで、クラスタ数 10 を指定。
rng(1); % for reproducibility kmeansIdx = kmeans(newX(:,1:ndim),10,'Display','iter','MaxIter',1000);
第 2 主成分までプロットしてみて、クラスタリング結果の雰囲気をつかみます。左側は正解ラベルで、右側は k-means のクラスタリング結果です。なんとなく同じように分類できている気もしますが、あからさまに間違えているのもありますね (たとえば、左端の 1 のクラスタ (オレンジ) は、k-means では 2 個のクラスタ (オレンジと緑) に分かれてしまっているように見えます)。
figure(2); clf c = hsv(10); subplot(121); for ii=1:10 plot(newX(trainY==ii-1,1),newX(trainY==ii-1,2),'.','Color',c(ii,:)); hold on end legend(num2str((0:9)')); title('Original label') xlabel('PC1') ylabel('PC2') subplot(122); for ii=1:10 plot(newX(kmeansIdx==ii,1),newX(kmeansIdx==ii,2),'.','Color',c(ii,:)); hold on end legend(num2str((1:10)')); title('K-means clustering') xlabel('PC1') ylabel('PC2')
続いて、クラスタリングの誤答率 (損失) を計算します。k-means の結果 (kmeansIdx) は、実際の 0-9 の数字には対応していないので、どのクラスタにどの数字が多かったかを調べて、各クラスタに対応する数字 0-9 を見つけています。それから誤答率を計算。
idxEdge = -0.5:9.5; idxAxis = 0:9; for ii=1:10 idx = kmeansIdx==ii; N = histcounts( trainY(idx), idxEdge); [~,maxIdx] = max(N); cluValue(ii) = idxAxis(maxIdx); end cluIdx = nan(size(kmeansIdx)); for ii=1:10 cluIdx(kmeansIdx==ii) = cluValue(ii); end % loss L = sum(trainY(:)~=cluIdx)/length(trainY);
最後に、クラスタリング結果の一部を表示してみます。誤答率は 41% (正答率 59%)。半分近く間違えていますが、チャンスレベルの正答率 10% よりはかなり良いでしょう (汗)。一部の数字、0, 2, 6 などは結構よさそうな雰囲気です。1 のクラスタ (と 0 のクラスタも) は、上で見たように 2 つ出現してしまっていますね。
figure(3); clf for ii=1:10 subplot(10,1,ii) Xtmp = trainX(kmeansIdx==ii,:); im = []; for jj=1:20 im = [im reshape(Xtmp(jj,:),28,28)']; end imagesc(im) axis image off text(-10,14,num2str(cluValue(ii))) if ii==1 title(sprintf('Loss = %.2f',L)) end end colormap gray
この方法で、k-means に食べさせる次元数をさらに増やしても、精度はこれ以上ほとんど改善しませんでした。精度を上げたければ、もっと高級な手法を使う必要がありそうです。
今回は、Matlab の Statistics and Machine Learning Toolbox に入っている kmeans 関数を使いましたが、自分で k-means のアルゴリズムを書き下してみた記事はこちら ↓