機械学習の5、サポートベクターマシン

概要

本文には、大文字表現=行列/マトリクス、\(\boldsymbol{Bold}\)小文字表現=ベクトル\(R^d\)、普通小文字表現=スカラー\(R\) と記す。

複数の標本(Sample)が分類可能と仮定する場合、直線で異なるラベルに属する標本を分類して、線形SVMが適用される。しかし、このような直線が存在しない場合、カーネルトリック(Kernel Trick)を使用して、写像された高次元空間で標本分類用の超平面(Hyper Plain)が見つかり、非線形SVMが適用される。ロジスティック回帰とは違って、境界(Border)に最も近い標本(Support Vector)の間隔(Margin)を最大化するから、標本が境界にあることが拘束条件(s.t.)である。ラグランジュ乗数法(Lagrange’s method)を利用して拘束条件つきの最適問題を拘束なし問題に簡略化して解くことになる。これで機械学習が一気に難易度を増すことになる。

線形SVM基本形

$$ \underset{w,b}{\operatorname{min}} \frac{1}{2}\boldsymbol{w}^T\boldsymbol{w}\\
s.t. \space \space y_i(\boldsymbol{w}^T \boldsymbol{x_i}+b) \ge 1, \space i=1,2,…,n $$
基本(Primal)形を解くのが難しいから、双対(Dual)形で解くようになる。

線形SVM双対形

$$ \underset{\alpha}{\operatorname{min}} \frac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} \alpha_i \alpha_j y_i y_j \boldsymbol{x}_i^T \boldsymbol{x}_j -\sum_{i=1}^{n} \alpha_i \\
s.t. \space \space \sum_{i=1}^{n} \alpha_i y_i =0 , \alpha_i \ge 0, i=1,2,…,n $$

線形SVM双対形でも分類できない場合、標本\(\boldsymbol{x}\) を \(\phi(\boldsymbol{x})\)で高次元空間\(R^{\bar{d}}\)に写像する。

非線形SVM基本形

線形SVM基本形から、\(\boldsymbol{x}\)を\( \phi(\boldsymbol{x}) \)に置き換えて以下の式になる。
$$ \underset{w,b}{\operatorname{min}} \frac{1}{2}\boldsymbol{w}^T\boldsymbol{w}\\
s.t. \space \space y_i(\boldsymbol{w}^T \phi (\boldsymbol{x_i})+b) \ge 1, \space i=1,2,…,n $$
カーネル関数は写像(Mapping)された高次元空間における内積(Inner Product)\(\small k(\boldsymbol{x}_i,\boldsymbol{x}_j)= \phi(\boldsymbol{x}_i)^T \phi(\boldsymbol{x}_j) \)、基本形を解くのが難しいから、双対形で解くようになる。

非線形SVM双対形

非線形SVM基本形から、\(\boldsymbol{x}\)を\( \phi(\boldsymbol{x}) \)に置き換えて以下の式になる。
$$ \underset{\alpha}{\operatorname{min}} \frac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} \alpha_i \alpha_j y_i y_j \phi(\boldsymbol{x}_i)^T \phi(\boldsymbol{x}_j) -\sum_{i=1}^{n} \alpha_i \\
s.t. \space \space \sum_{i=1}^{n} \alpha_i y_i =0 , \alpha_i \ge 0, i=1,2,…,n $$

ソフト間隔SVM基本形

理論的にはカーネル関数見つけることができるが、現実においては適切なカーネル関数を見つけることは困難な場合あり、なおかつ標本にはノイズが含まれるから、いかに誤分類なしカーネル関数を求めるなら、オーバーフィッティングに陥てしまう。ここで少量の標本に対して誤分類(Miss-Classification)を許す。
$$ \underset{w,b,\xi}{\operatorname{min}} \frac{1}{2}\boldsymbol{w}^T\boldsymbol{w}+C\sum_{i=1}^{n}\xi_i \\
s.t. \space \space y_i(\boldsymbol{w}^T \phi (\boldsymbol{x_i})+b) \ge 1-\xi_i, \space i=1,2,…,n, \\ \xi_i \ge 0, \space i=1,2,…,n $$
間隔を最適化しながら誤分類が許すが、このような誤分類の標本数がなるべく少なくする必要がある。ここでCが間隔と誤分類の標本数のバランスを取る正則化パラメータで、誤分類をどれだけ許すかを決める。Cが大きくすると誤分類の標本数が少なくなり、Cが小さくすると誤分類の標本数が多くなる傾向がある。\(\xi_i\)はスラック変数(Slackness Variable)と言って、誤分類の度合がどれだけあるのか(どれだけ緩めるか)を表すパラメータである。

svm_slackness_variable
svm_slackness_variable

ソフト間隔SVM双対形

$$ \underset{\alpha}{\operatorname{min}} \frac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} \alpha_i \alpha_j y_i y_j \phi(\boldsymbol{x}_i)^T \phi(\boldsymbol{x}_j) -\sum_{i=1}^{n} \alpha_i \\
s.t. \space \space \sum_{i=1}^{n} \alpha_i y_i =0 , 0 \le \alpha_i \le C, i=1,2,…,n $$

ヒンジ損失

前述のスラック変数\(\xi\)に関連してヒンジ損失(Hinge Loss)は以下の式で表す。
$$ l(s) = max(0, 1-s) $$ただし、\(s=y_i(\boldsymbol{w}^T \phi (\boldsymbol{x}))\)は標本が正しく分類されるかの信頼度である。標本が正しく分類される場合、ヒンジ損失=0、標本が誤分類される場合、つまり\(s<0\)の場合、ヒンジ損失は大きくなる。まさにヒンジ開閉のように思わせる。

実装の例

pythonのscikit-learnの機能はsklearnというライブラリに組み込まれる。sklearnにsvmというライブラリが内蔵されているので、非常に便利に利用できる。今回は非線形SVMカーネル関数のガウシアンRBF(Gaussian Radial Basis Function)を利用して標本グループを分類してみよう。

svm_gaussian_kernel
svm_gaussian_kernel

\(k(\boldsymbol{x}_i,\boldsymbol{x}_j)= exp(- \gamma (x_i-x_j)^2)\)より、\(\gamma\)はガウシアン曲面の幅を制御する調整用のパラメーターという。特徴量1のガウス分布曲線を思い出すとイメージが分かる。

ソースコードhttps://github.com/soarbear/Machine_Learning/tree/master/svm

結論

サポートベクターマシンSVMのパラメーター\((w,b)\)は、サポートベクターのみに決定されて、他の標本には関係なし。最大間隔により標本数が少なくとも分類の正確さが高い、汎化能力(Generalization Capability)が優れる、非線形問題が解決できるカーネル関数の導入とのメリットをもつ反面、標本数が大きい場合、カーネル関数の内積の計算とラグランジュ乗数の計算に標本数が関連して多大な計算量につながる。標本数は特徴量次元数より遥かに大きい\(n>>d\)場合、深層学習のほうがSVMより優れる。また適切なカーネル関数を選択するのと、ハイパーパラメータを決めるのも難しい。以上の検討結果を開発に取り組んでいきたい。

参考文献

[1] B. E. Boser, I. M. Guyon, and V. N. Vapnik. A training algorithm for optimal margin classifi ers. In Proceedings of the Annual Workshop on Computational Learning Theory, pages 144–152, 1992.5
[2] S. Boyd and L. Vandenberghe. Convex optimization. Cambridge university press, 2004.4
[3] C.-C. Chang and C.-J. Lin. LIBSVM: A library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2(3):27, 2011.10
[4] Zhanghao. Derive support vector machine (SVM) from zero. 2018.10

1+

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア

機械学習の4、ロジスティック回帰

概要

本文には、大文字表現=行列/マトリクス、\(\boldsymbol{Bold}\)小文字表現=d次元(特徴量)ベクトル\(R^d\)、普通小文字表現=スカラー\(R\) と記す。

ロジスティック回帰(Logistic Regression)は、sigmoid関数を介して線形回帰\((\boldsymbol{w}^T\boldsymbol{x}+b)\)を\(0|1\)にマッピングして2クラス問題に適用される。Nクラス分類問題の場合、\((\boldsymbol{w}^T\boldsymbol{x}+b)\)のNセットを取得してから例えばsoftmax関数で振り分ける多重分類問題に適用される。「回帰」の表現があるのに、実はロジスティック分類である。線形回帰\((\boldsymbol{w}^T\boldsymbol{x}+b)\)を利用したわけではないかと考えられる。

線形回帰の表現

$$ f(\boldsymbol{x}) = \boldsymbol{w}^T\boldsymbol{x} + b $$
ただし、\(\boldsymbol{w},\boldsymbol{x} \in R^d,\space b\in R, \boldsymbol{w}’=[b \space \space \boldsymbol{w}]^T, \boldsymbol{x}’ =[1 \space \space \boldsymbol{x}]^T \)にすると、上式が以下の式に簡略化される。
$$ f(\boldsymbol{x}’) = \boldsymbol{w’}^T \boldsymbol{x}’ → f(\boldsymbol{x}) = \boldsymbol{w}^T x $$

分類方法

Perception(認知)機能を果たす活性化関数がsigmoid関数で\(0|1\)の2クラスに分類する場合、以下の表現がある。
$$ p = \sigma[f(\boldsymbol{x})] = \frac{1}{1+e^{-f(\boldsymbol{x})}} = \frac{1}{1+e^{-\boldsymbol{w}^T\boldsymbol{x}}} $$
ただし、\(f(\boldsymbol{x}) = \boldsymbol{w}^T\boldsymbol{x}, \sigma=\)sigmoid関数、区間(0,1)に入る。

コスト関数

$$ c(\boldsymbol{w}) = ln[\prod_{i=0}^{n-1} p^{1-y_i}(1-p)^{y_i}] $$ただし、\(y_i=0|1,p=\sigma[f(\boldsymbol{x})]\)
上式コスト関数、\((p|(y_i=0|1)\)の総乗を最大にすると、つまり最尤推定法を適用する。最尤推定(=コスト関数の最大値)には、勾配を利用してコスト関数が早く収束するつまり最大値を求める方法で、\( \boldsymbol{w}^*= \underset{\boldsymbol{w}}{\operatorname{arg max}}(c(\boldsymbol{w})) \)を求めて、分類の境界線を決めることになる。

※ 最尤法あるいは最尤推定法とは任意のある観測値に対して確率を最大にする母数の推定値を求めようとする方法。

コスト関数の勾配

複数標本\((\sum^n)\)に対して、上記コスト関数が\(\boldsymbol{w}\)への勾配は以下の式で求める。
$$ \nabla c(\boldsymbol{w}) = \sum_{i=0}^{n-1}(y_i-p_i) \boldsymbol{x}_i $$
sigmoid関数を選んだかいで、勾配が簡潔に表現可能となる。

勾配上昇

$$ \boldsymbol{w}^*= \underset{\boldsymbol{w}}{\operatorname{arg max}}[c(\boldsymbol{w})] → \boldsymbol{w}_{t}=\boldsymbol{w}_{t-1}+\alpha \nabla c(\boldsymbol{w_{t-1}}) $$
ただし、\(\alpha\)は小さい常数、例えば0.01~0.001にする。指定した回数または\(\boldsymbol{w}\)の変化しなくなるまで\(\boldsymbol{w}\)の更新を繰り返す。\(\boldsymbol{w}\)が\(c(\boldsymbol{w})\)の勾配方向へ繰り返すので\(c(\boldsymbol{w})\)が素早く最大に達する。最後の\(\boldsymbol{w}=\boldsymbol{w}^*\)と見なして、つまりロジスティック回帰の分類器ができてしまう。

ランダム勾配上昇

上記勾配上昇をさらに工夫して、\(\nabla c(\boldsymbol{w})\)をランダム関数\(g(\boldsymbol{w})\)に置き換えると速くコスト関数cの最大値に辿り着く場合がある。勿論、この置き換え関数の期待値がを勾配に等しいのを満たす必要があり、勾配値の周りにランダムな変動に相当する。
$$ \boldsymbol{w}_{t}=\boldsymbol{w}_{t-1}+\alpha g(\boldsymbol{w_{t-1}}) $$

分類実施

\(\boldsymbol{w}\)を分かると、前述分類方法で、\(\boldsymbol{x}\)を分類する。
$$ p = \sigma[f(\boldsymbol{x})] \ge 0.5 → Class 1, \space A ,\space etc \\
p = \sigma[f(\boldsymbol{x})] < 0.5 → Class 0,\space B ,\space etc $$

実装の例

クラス\(=2\)、特徴量\(=2\)、標本数\(=m+n\)の標本組、ポイントグループ\(\small A([m,2]|y_i=1)=[(x1_0,x2_0),(x1_1,x2_1),…,(x1_m,x2_m)]\)と、ポイントグループ\(\small B([n-m,2]|y_i=0)=[(x1_{m+1},x2_{m+1)}),…,(x1_n,x2_n)]\)の分類に使わられる境界線\(y=\boldsymbol{w}^T\boldsymbol{x}\)の係数を求めて、境界線を描く。ただし、\(\small \boldsymbol{w},\boldsymbol{x}\in R^d, y\in \{0,1\},\)\(\small \boldsymbol{w}=[b \space \space w1 \space \space w2]^T, \boldsymbol{x}=[1 \space \space x1 \space \space x2]^T\)

ソースコードhttps://github.com/soarbear/Machine_Learning/tree/master/logistic_regression

結果

logistic_regression
logistic_regression

標本の増減により、境界線を計算しなおすのと、正確さから、SVMなどもっと進化した方法論がある。

参考文献

「Machine Learning in Action」、Peter Harrington氏

1+

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア

機械学習の3、単純ベイズ

概要

単純ベイズ分類器(Classifier)はベイズ定理(Bayes’ theorem)の条件付き確率(事後確率、事前確率)による分類器また分類法である。
$$ \small P(Class | F_1, F_2,,,F_n) = \frac{P(F_1, F_2,,,F_n | Class) P(Class)}{P(F_1, F_2,,,F_n)} $$
\( P(Class | F_1, F_2,,,F_n) \)は事後確率、\( P(F_1, F_2,,,F_n | Class) \)は事前確率で、P(Class)とも先に求められるとされる。{P(F_1, F_2,,,F_n)は分類空間の類(Class)同士に変わらないので、計算が省略可能となる。

また特徴空間にある特徴(Feature)同士がそれぞれ相関しないことに簡略化して、単純ベイズ分類器となる。
$$ \scriptsize P = \frac{P(F_1|Class)P(F_2|Class)…P(F_n|Class)|Class) P(Class)}{P(F_1, F_2,,,F_n)} $$

前述kNN、kd-treeとも分類法だが、条件付き確率のベイズ定理を生かした、計算量が少ない確率論らしい分類法である。

実装の例

訓練データセットと単純ベイズ分類器により、ある言葉リストから暴力傾向あるかどうかを判別する。

'''
Created on Oct 19, 2010
@author: Peter
'''
from numpy import *
#
# Load postingList -> myVocabList.
#
def loadDataSet():
  postingList=[['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
               ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
               ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
               ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
               ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
               ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
  classVec = [0,1,0,1,0,1]    #1 is abusive, 0 not
  return postingList,classVec
#
# Create myVocabList from postingList.
#
def createVocabList(dataSet):
  vocabSet = set([])  #create empty set
  for document in dataSet:
      vocabSet = vocabSet | set(document) #union of the two sets
  return list(vocabSet)
#
# Create 2 vectors(2 categories) from myVocabList.
#
def setOfWords2Vec(vocabList, inputSet):
  returnVec = [0]*len(vocabList)
  for word in inputSet:
      if word in vocabList:
          returnVec[vocabList.index(word)] = 1
      else: print(f"the word: {word} is not in my Vocabulary!")
  return returnVec
#
# Create 2 conditional probability vectors(p0Vect & p1Vect) from trainMatrix.
# Create 1 category probabilitiesies(pAbusive) from trainCategory.
#
def trainNB0(trainMatrix,trainCategory):
  numTrainDocs = len(trainMatrix)
  numWords = len(trainMatrix[0])
  pAbusive = sum(trainCategory)/float(numTrainDocs)
  p0Num = ones(numWords); p1Num = ones(numWords)      #change to ones() 
  p0Denom = 2.0; p1Denom = 2.0                        #change to 2.0
  for i in range(numTrainDocs):
      if trainCategory[i] == 1:
          p1Num += trainMatrix[i]
          p1Denom += sum(trainMatrix[i])
      else:
          p0Num += trainMatrix[i]
          p0Denom += sum(trainMatrix[i])
  p1Vect = log(p1Num/p1Denom)          #change to log()
  p0Vect = log(p0Num/p0Denom)          #change to log()
  return p0Vect,p1Vect,pAbusive
#
# Perform bayes formula. 1: testEntry is abusive, 0: not abusive.
#
def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
  p1 = sum(vec2Classify * p1Vec) + log(pClass1)    #element-wise mult
  p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
  if p1 > p0:
      return 1
  else: 
      return 0
#
# Test 2 testEntries.
#  
def testingNB():
  listOPosts,listClasses = loadDataSet()
  myVocabList = createVocabList(listOPosts)
  trainMat=[]
  for postinDoc in listOPosts:
      trainMat.append(setOfWords2Vec(myVocabList, postinDoc))
  p0V,p1V,pAb = trainNB0(array(trainMat),array(listClasses))
  testEntry = ['love', 'my', 'dalmation']
  thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
  print(f"{testEntry}, 'classified as:' {classifyNB(thisDoc,p0V,p1V,pAb)}")
  testEntry = ['stupid', 'garbage']
  thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
  print(f"{testEntry}, 'classified as:' {classifyNB(thisDoc,p0V,p1V,pAb)}")
#
# main.
#
if __name__ == '__main__':
  testingNB()

ソースコード→https://github.com/soarbear/Machine_Learning/tree/master/bayes

結果

bayes_result
bayes_result

参考文献

「Machine Learning in Action」、Peter Harrington氏

追記

単純ベイズは、回帰に使われた報道がある。
参考文献 Technical Note:Naive Bayes for Regression

0

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア

深層学習の4、RNN LSTM stateful

概要

RNNやLSTMのstatefulとは、各Batchのサンプルの状態が次のバッチのサンプルのための初期状態として再利用されるのであれば、Trueに設定するという。計算の便宜を図ってbatch_size毎に分けて処理するが、株価、天気など時系列データに対して、状態を引き継ぐのであれば、stateful=Trueのケースに当たる。stateful=Trueの場合、以下で対応する。

・batch_size引数をモデルの最初のLayerに渡して、batch_sizeを明記する。例えば、サンプル数が32、time_stepsが10、input_dim(入力データの特徴数)が16の場合には,batch_size=32と明記する。
・RNN、LSTMでstateful=Trueに指定する.
・model.fit()を呼ぶときにshuffle=Falseに指定する。
・model.reset_states()を実行して、モデルの全てのLayerの状態を更新する。または、layer.reset_states()を実行して、特定のLayerの状態を更新する。
・model.predict()、model.train_on_batch()、model.predict_classes()関数はいずれもmodel.reset_states()を実行してstatefulに指定したLayerの状態を更新する。
・model.fit()を呼ぶときに、以下例のようにbatch毎にmodel.reset_states()を実行する。

注意点は、model.reset_states()は0にリセットではなく、前回のbatchの最終状態値を次回のbatchの初期状態値として引き継ぐ。

実装の例

print('Build STATEFUL model...')
model = Sequential()
model.add(LSTM(10, batch_input_shape=(1, 1, 1), return_sequences=False, stateful=True))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
...
class ResetStatesCallback(Callback):
    def __init__(self):
        self.counter = 0

    def on_batch_begin(self, batch, logs={}):
        if self.counter % max_len == 0:
            self.model.reset_states()
        self.counter += 1
...        
model.fit(x, y, callbacks=[ResetStatesCallback()], batch_size=1, shuffle=False)
...
print('Train...')
for epoch in range(15):
    for i in range(len(X_train)):
        tr_loss, tr_acc = model.train_on_batch(np.expand_dims(np.expand_dims(X_train[i][j], axis=1), axis=1),
                                               np.array([y_true]))
        model.reset_states()
...
    for i in range(len(X_test)):
        te_loss, te_acc = model.test_on_batch(np.expand_dims(np.expand_dims(X_test[i][j], axis=1), axis=1),
                                              y_test[i])
        model.reset_states()
...
        y_pred = model.predict_on_batch(np.expand_dims(np.expand_dims(X_test[i][j], axis=1), axis=1))
        model.reset_states()

【追記】

model.fit()のcallbackについて、以下アラームが出る。

RuntimeWarning: Method (on_train_batch_begin) is slow compared to the batch update. Check your callbacks.

これは、model.fit()のcallbackがbatch updateより遅いという、なおかつ、KERASの解釈では「RNNをstatefulにするとは,各バッチのサンプルの状態が,次のバッチのサンプルのための初期状態として再利用されるということを意味する」とあるので、model.fit()を以下のように修正する。

for i in range(NUM_EPOCHS):
    for j in range(BATCH_SIZE*NUM_BATCHES):
        model.fit(x, y, batch_size=1, epochs=1, shuffle=False)
        model.reset_states()

ただし、batchが順番に処理してGPU/TPUの並行処理ができなくなるので、処理時間が余儀なく伸びる。精度と処理時間の兼ね合いから課題として、stateless、batch_sizeを増やすのと、statefulにするのと、どちらがよいのかを案件ごとに実験することが大切だと考える。

参考文献

Keras FAQ
・「Deep Learning with Keras」、by Antonio Gulli氏

1+

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア

機械学習の2、kd-tree最近傍探索

概要

kd-tree(k-dimensions tree)はk次元空間の分割を表す二分木である。kd-treeの構築は、座標軸に垂直な超平面でk次元空間を連続的に分割して、一連のk次元超長方形領域を形成する。 kd-treeの各ノードは、k次元の超長方形領域に対応する。kd-treeを使用すると、一部のインスタンスとの計算はしないため、計算量が削減される。よって、kd-treeは早く探索できるようにインスタンスポイントをk次元空間に格納するツリー型のデータ構造である。それで新しいインスタンスのクラス(属するクラス)をkd-treeによる判明する最近傍探索のアルゴリズムは以下のとおり。
【step1】k次元データセットで最大の分散をもつ次元xを選択し、次に中央値mを次元の中間点として選択してデータセットを分割し、2つのサブセットを取得する。
【step2】2つのサブセットに対してステップstep1のプロセスを繰り返して、すべてのサブセットが再分割できなくなるまで繰り返す。
【step3】step1とstep2で根から葉っぱまで木ができてしまう。
【step4】根から葉っぱまで、あるルール「\(x \lt m\)→左枝、\(x \geq m\)→右枝」で新しいインスタンスに最も近いインスタンス(葉っぱ、ノード)をみつける。
【step5】step4の葉っぱから、逆方向で根まで遡って、step4よりもっと近いインスタンスがあるか探索する。要するに、新しいインスタンスにもっとも近い既存インスタンスをみつける。
【step6】step4またはstep5でみつけたもっとも近い既存インスタンスのクラスは結果となる。

実装の例

下図から、グリーンポイント[2, 4.5]に最近傍ポイントを探索する。

kd_tree_newPoint
kd_tree_newPoint

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
#from time import time

#
# load data from dataset file. 
#
def load_data(fileName):
    data_mat = []
    with open(fileName) as fd:
        for line in fd.readlines():
            data = line.strip().split()
            data = [float(item) for item in data]
            data_mat.append(data)
    data_mat = np.array(data_mat)
    label = data_mat[:, 2]
    data_mat = data_mat[:, :2]
    return data_mat, label
#
# create a tree for dataset.
#
def create_kdtree(dataset, depth):
    n = np.shape(dataset)[0]
    tree_node = {}
    if n == 0:
        return None
    else:
        n, m = np.shape(dataset)
        split_axis = depth % m
        depth += 1
        tree_node['split'] = split_axis
        dataset = sorted(dataset, key=lambda a: a[split_axis])
        num = n // 2
        tree_node['median'] = dataset[num]
        tree_node['left'] = create_kdtree(dataset[:num], depth)
        tree_node['right'] = create_kdtree(dataset[num + 1:], depth)
        return tree_node
#
# search k near points on the tree. 
#
def search_kdtree(tree, data):
    k = len(data)
    if tree is None:
        return [0] * k, float('inf')
    split_axis = tree['split']
    median_point = tree['median']
    if data[split_axis] <= median_point[split_axis]:
        nearest_point, nearest_distance = search_kdtree(tree['left'], data)
    else:
        nearest_point, nearest_distance = search_kdtree(tree['right'], data)
    
    # the distance between data to current point.
    now_distance = np.linalg.norm(data - median_point)
    if now_distance < nearest_distance:
        nearest_distance = now_distance
        nearest_point = median_point.copy()
        
    # the distance between hyperplane.
    split_distance = abs(data[split_axis] - median_point[split_axis])
    if split_distance > nearest_distance:
        return nearest_point, nearest_distance
    else:
        if data[split_axis] <= median_point[split_axis]:
            next_tree = tree['right']
        else:
            next_tree = tree['left']
        near_point, near_distance = search_kdtree(next_tree, data)
        if near_distance < nearest_distance:
            nearest_distance = near_distance
            nearest_point = near_point.copy()
        return nearest_point, nearest_distance
#
# main.
#
if __name__ == '__main__':
    data_mat, label = load_data('/content/drive/My Drive/Colab Notebooks/Machine_Learning/kd_tree/dataset.txt')
    fig = plt.figure(0)
    ax = fig.add_subplot(111)
    ax.scatter(data_mat[:, 0], data_mat[:, 1], c=label, cmap=plt.cm.Paired)

    new_point = [2, 4.5]
    kdtree = create_kdtree(data_mat, 0)
    print(kdtree)
    #start = time()
    nearest_point, near_dis= search_kdtree(kdtree, new_point)
    #print(time()-start)
    ax.scatter(new_point[0], new_point[1], c='g', s=50)
    ax.scatter(nearest_point[0], nearest_point[1], c='r', s=50)
    plt.show()

ソースコード→https://github.com/soarbear/Machine_Learning/tree/master/kd_tree

結果

kd_tree_findNearestPoint
kd_tree_findNearestPoint

kd-treeが小さい(例えば\(k=20\))場合、アルゴリズムの探索効率はkNNより非常に高くなる。ただし、データの次元が増えると(例えば、\(k≥100\))、探索効率が急速に低下する。データセットの次元がkであると仮定すると、効率的な探索を実現するために、データサイズNが、\(N >> 2^k\)を満たすことが必要である。kd-treeで高次元データにも適用できるように、Jeffrey S. BeisとDavid G. Loweは、改善されたアルゴリズムKd-tree with BBF(Best Bin First)の提案に貢献した。ある探索精度を確保するという前提で探索を高速化するようになる。

参考文献

「Machine Learning in Action」、Peter Harrington氏

1+

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア