F# Implementation of BackPropagation Neural Network for Pattern Recognition(LifeGame)
この記事は、F# Advent Calendar 2011の21日目です。
きっかけは、11月19日に札幌で行われた第64回CLR/H勉強会で、愛甲健二さん(@07c00)がお話してくれた「コンピューターに萌えを教えてみたよ」というセッションです。「アダルトサイトの検知」のメカニズムだったり、愛甲さん自身の"萌えの嗜好"をコンピューターに学習させてみるという少しアレゲなテーマでのお話しでしたが、内容はとても真面目で面白かった。見慣れない数式など、その全てを理解することはできませんでしたが、ニューラルネットワークの雰囲気や概要がわかりました。オライリーの「集合知プログラミング」でニューラルネットワークについて少し読んだことがあったり、何となく見聞きしたことはありましたが、基本的な考え方を知ったのはそのときがはじめてです。とても面白くもっと知りたいと思ったので、勉強会の後にモクモクとニューラルネットワークに関する情報を集めて自分なりに勉強してみました。"脳を模してモデル化したアルゴリズムによって、コンピュータに学習能力をもたせる。" なんだか面白かっこいい!じゃないですか。いろいろと調べているうちに、これなら自分にも実装できそう!と思ったので、みんな大好きF#でやってみました。F#の記事というよりも、むしろニューラルネットワーク成分多い目だが、
「大丈夫だ、ゆるふわなので問題ない。」
ニューラルネットワークとは
情報分野におけるニューラルネットワークとは、われわれ人間の脳の神経回路の仕組みを模してモデル化したもので、コンピュータに学習能力を持たせることで、様々な問題を解決しようとするアプローチのひとつで、人工知能の一分野で機械学習というジャンルに属します。もともとニューラルネットワークという研究分野は、人間が自然と行っているパターン認識や経験則を導き出したりする仕組みをモデル化して、ロボットが経験から学習していくことで、正しい反応や行動を獲得していく仕組みを実現することを目的とした側面が強かったようですが、次第に工学寄りにシフトしてきて、「データの中で明らかなものから、明らかではないものを予測する(ことをコンピュータにやらせるための)」技術や理論を指すことがほとんどになってきたようです。近年の自然言語処理や画像のパターン認識、データマイニング、あるいは信用リスク格付け予測など、ビジネス用途での応用分野における成功を要因に、普及と発展が進んでいて現在も広くその研究や応用が進められている。
教師あり学習というアプローチ
機械学習の扱う問題には、大きく分けて教師あり学習 (supervised learning) と、教師なし学習 (unsupervised learning) の2つがある。 単純にその2つに分類することができない、複合的な問題や独自に発展した特殊問題もあるようですが、基本的には、この2つに分類することができる。愛甲さんがお話してくれた、アダルトサイトの検知だったり、「コンピューターに萌えを教えてみたよ」は、ちょうど教師あり学習にあたります。教師あり学習では、入力データ(条件として明らかとなっている情報)が与えられたとき、これに対する出力(答えが明らかではない情報)を正しく予測することが目的です。 もちろん、ただ入力を入れられただけでは、コンピューターは答えとして何を出力したらよいのかわかりません。そこで、訓練データ(あるいは教師データ)と呼ばれる、入出力のペアとしたデータを、あらかじめコンピューター複数与えます。「コレの入力があったら、コレを出力しなさい」というパターンをいくつか与えて機械に学習させます。新しい入力データが来たときに、それに対する正しい出力をするような機械(関数)を作るのが目的です。複雑で広い領域の問題では、すべてのパターンを機械に学習させることは不可能で、当然、あらかじめ学習に用いる訓練データの中には現れない入力データが与えられる場合もあります。そのようなデータに対応するために、与えられた訓練データを一般化して、未知のデータに対処して予測を出力する能力(汎化能力)がなるべく高くなるような、学習アルゴリズムを設計することが、教師あり学習の主要なテーマとなります。ニューラルネットワークは、汎化能力の高い教師あり学習のアプローチのひとつです。
F#でニューラルネットワーク
F#でバックプロパゲーションアルゴリズムを用いた3層パーセプトロンを実装しました。時間がなくて整理しきれなかった部分があり心残りな面もありますが、以下、NNモジュールです。参考になればと思い、普段は書かないような説明的なコメントも多めに書いてみました。
F#
namespace NN open System [<AutoOpen>] module NN = /// 訓練データパターン type Pattern = { Inputs : double list; (* 入力 *) TeachingSignal: double list (* 教師信号 *) } // 層をつくる let createLayer size build = let rec create size acc = if size <= 0 then acc else create ((-) size 1) (acc@[build ()]) create size [] /// シグモイド関数 /// 関数のある点での勾配を求めて誤差Eが少なくなる方向へ結合重みWを変化させていきます。 let sigmoid input bias = /// α(gain)を1.0とするとき標準シグモイド関数と言う let gain : double = 5.0 1.0 / (1.0 + Math.Exp(-gain * (input + bias))) /// ニューロン type Neuron = { mutable bias : double // バイアス mutable error : double // E mutable input : double // 入力 mutable output : double // 出力 learnRate : double // 学習レート weights : Weight list // 重み } /// 出力 member this.Output with get () = if (this.output <> Core.double.MinValue) then this.output else // 判別問題を学習させる場合は階段関数やシグモイド関数を用いる。回帰問題を学習させる場合は線形関数を用いる。 // 今回はシグモイドで sigmoid this.input this.bias and set (v) = this.output <- v // 重み and Weight = { In: Neuron; mutable Value:double } // 層 and Layer = Neuron list /// 活性化 let activate neuron = neuron.input <- 0.0 for w in neuron.weights do neuron.input <- neuron.input + w.Value * w.In.Output /// エラーフィードバック let errorFeedback (neuron:Neuron) (input:Neuron) = neuron.Output * (1.0 - neuron.Output) |> fun derivative -> // より大きな重みで接続された前段のニューロンに対して、局所誤差の責任があると判定する。 let w = neuron.weights |> List.find (fun t -> t.In = input) neuron.error * derivative * w.Value /// 各ニューロンの重みを局所誤差が小さくなるよう調整する。 let adjustWeights (neuron:Neuron) (value:double) = neuron.Output * (1.0 - neuron.Output) |> fun derivative -> neuron.error <- value for i in [0..neuron.weights.Length-1] do // 出力と教師信号が異なれば、出力値を少しだけ教師信号寄りに重みを修正する neuron.weights.[i].Value <- neuron.weights.[i].Value + (neuron.error * neuron.learnRate * derivative * neuron.weights.[i].In.Output) // バイアスの補正 neuron.bias <- neuron.bias + neuron.error * neuron.learnRate * derivative /// 素のニューロンを生成 let createNewron () = { bias = 0.0 error = 0.0 input = 0.0 output = Core.double.MinValue learnRate = 0.5 weights = [] } /// 入力についてランダムな重みを持つニューロンを生成 let createNewron' inputs (rnd:Random) = let createWeights () = inputs |> List.map (fun input -> { In = input; Value = rnd.NextDouble() * 2.0 - 1.0 }) |> List.fold (fun a b -> a@[b]) [] { bias = 0.0 error = 0.0 input = 0.0 output = Core.double.MinValue learnRate = 0.5 weights = createWeights () } /// ネットワーク type Network = { InputSize : int MiddleSize : int OutputSize : int RestartAfter : int TryCount : int Inputs : Layer Middle : Layer Outputs : Layer Patterns : Pattern list } /// 入力層、中間層、出力層のニューロンを生成 let createNeuron inputSize middleSize outputSize = let rnd = new Random() let inputs = createLayer inputSize (fun () -> createNewron ()) let middle = createLayer middleSize (fun () -> createNewron' inputs rnd) let outputs = createLayer outputSize (fun () -> createNewron' middle rnd) inputs, middle, outputs /// ニューラルネットワークの各ニューロンを活性化 let networkActivate (network:Network) (pattern:Pattern) = for i in [0..pattern.Inputs.Length - 1] do network.Inputs.[i].Output <- pattern.Inputs.[i] for neuron in network.Middle do activate neuron for output in network.Outputs do activate output network.Outputs |> List.map (fun output -> output.Output) /// 初期化 let initializeNetwork (network:Network) = let inputs,middle,outputs = createNeuron network.InputSize network.MiddleSize network.OutputSize { network with Inputs = inputs; Middle = middle; Outputs = outputs; TryCount = 0 } /// 訓練データをNetworkに読み込む let loadPatterns (network:Network) (trainingData :(double list * double list) list) = let rec create n acc = if n <= 0 then acc else let inputs,teachingSignal = trainingData.[n] create ((-) n 1) (acc@[{Inputs=inputs; TeachingSignal=teachingSignal}]) { network with Patterns = create (trainingData.Length-1) [] } /// 訓練 let training (network:Network) = /// 重み調整:バックプロパゲーション let adjustWeights (delta:double) = // 個々のニューロンの期待される出力値と倍率(scaling factor)、要求された出力と実際の出力の差を計算する。これを局所誤差と言う。 for output in network.Outputs do adjustWeights output delta for neuron in network.Middle do // そのように判定された前段のニューロンのさらに前段の中間層における隠れニューロン群について同様の処理を行う。 adjustWeights neuron (errorFeedback output neuron) let mutable error = 0.0 for pattern in network.Patterns do // ネットワークの出力とそのサンプルの最適解を比較する。各出力ニューロンについて誤差を計算する。 for i in [0..pattern.TeachingSignal.Length - 1] do let output = (networkActivate network pattern).[i] let delta = pattern.TeachingSignal.[i] - output adjustWeights delta // 二乗誤差でEを求める error <- error + Math.Pow(delta, 2.0) { network with TryCount = network.TryCount + 1}, error /// 三層ネットワークを生成 let createNetwork (inputs:Layer) (middle:Layer) (outputs:Layer) restartAfter = { InputSize = inputs.Length MiddleSize = middle.Length OutputSize = outputs.Length TryCount = 0 RestartAfter = restartAfter Inputs = inputs Middle = middle Outputs = outputs Patterns = [] }
線形分離問題「OR」および「AND」、非線形分離問題 XORを解く
以下、NNモジュールを使って各問題を解くF#
open System open NN open ListExModule [<STAThread>] // 三層分のニューロンを生成 let inputs,middle,outputs = createNeuron 2 3 1 // ニューラルネットワークを構築 let mutable (network:Network,error:float) = createNetwork inputs middle outputs 500 , 1.0 let rec flat = function | [] -> [] | x::_ when x = [] -> [] | x::xs -> x @ flat xs let rec insert v i lst = match i, lst with | 0, xs -> v::xs | i, x::xs -> x::insert v (i - 1) xs | i, [] -> failwith "境界外デス!" let condition = [1..8] let createPattern target ts (source: int list) = let inputs = condition |> List.map (fun i -> if source |> List.exists (fun x -> x = i) then 1.0 else 0.0) |> insert (if target = 1 then 1.0 else 0.0) 4 inputs,[ts] // AND問題 (線形分離可能) let andProblem = [ [0.0; 0.0;], [0.0]; [0.0; 1.0;], [0.0]; [1.0; 0.0;], [0.0]; [1.0; 1.0;], [1.0]; ] // OR問題 (線形分離可能) let orProblem = [ [0.0; 0.0;], [0.0]; [0.0; 1.0;], [1.0]; [1.0; 0.0;], [1.0]; [1.0; 1.0;], [1.0]; ] // XOR問題 (線形分離不可能) let xorProblem = [ [0.0; 0.0;], [0.0]; [0.0; 1.0;], [1.0]; [1.0; 0.0;], [1.0]; [1.0; 1.0;], [0.0]; ] // 訓練データをロード network <- loadPatterns network xorProblem // ここではXORを解く let main () = /// 実行 let run (network:Network) = while true do try Console.Write("Input x, y: ") let values = Console.ReadLine() let line = values.Split(',') let pattern = [0..network.InputSize-1] |> List.map (fun i -> Core.double.Parse(line.[i])) let inputs = List.init(network.InputSize) (fun i-> pattern.[i]) for output in networkActivate network { Inputs=inputs; TeachingSignal = []} do printfn "%d(%f)" <| Convert.ToInt32(output) <| output with | e -> Console.WriteLine(e.Message) // ニューラルネットワークを訓練する while error > 0.1 do let x,y = training network network <- x; error <- y printfn "Try %d\tError %f" x.TryCount y if network.TryCount > network.RestartAfter then network <- initializeNetwork network // 実行 run network main () Console.ReadKey () |> ignore
非線形分離問題も問題なく解けますな。
パターン認識でライフゲーム
バックプロパゲーションアルゴリズムで3層パーセプトロンによって構築したニューラルネットでXOR判定をすることができた。ここで終わってもよかったのですが、せっかくなので欲張って、もう少しだけ複雑な非線形問題のパターン認識もやらせてみました。第64回CLR/H勉強会の、@mentaroさんのセッションの最終デモで「ライフゲーム」が取り上げられていました。勉強会後に、「そういや、ライフゲームのセル生死判定は、判定対象セルとその周囲8つのセルをパターンとして捉えることがきて、セルの生死結果を教師データとするパターンをつくって、多数の訓練データで学習させることで、ニューラルネットワークにライフゲームの生死判定をさせることができるんじゃね?」と思いました。それを実践してみようという。練習にはちょうど良いですね。判定対象セルと周囲の8つのセルを合わせた9つのセルを入力とし、生死の結果を教師データとする訓練データを作成して、ニューラルネットに食わせてシバけばおーけー!
以下、NNモジュールを使って、
F#+XNAで、ニューラルネットのパターン認識でライフゲームなコード
F#
namespace LG open System open Microsoft.Xna.Framework open Microsoft.Xna.Framework.Graphics open Microsoft.Xna.Framework.Input open Microsoft.Xna.Framework.Content open NN open ListExModule [<AutoOpen>] module Assist = // リスト平坦化 let rec flat = function | [] -> [] | x::_ when x = [] -> [] | x::xs -> x @ flat xs // リストへの挿入 let rec insert v i lst = match i, lst with | 0, xs -> v::xs | i, x::xs -> x::insert v (i - 1) xs | i, [] -> failwith "境界外デス!" let condition = [1..8] // パターン生成 let createPattern target ts (source: int list) = let inputs = condition |> List.map (fun i -> if source |> List.exists (fun x -> x = i) then 1.0 else 0.0) |> insert (if target = 1 then 1.0 else 0.0) 4 inputs,[ts] // ライフゲームの教師データ生成 let lifeGameTrainingData = let pattern = [0..8] |> List.map (fun x -> combinations x condition) let survive = List.map (fun x -> x |> createPattern 1 1.0) // 生存 let keep = List.map (fun x -> x |> createPattern 0 0.0) // 維持 let birth = List.map (fun x -> x |> createPattern 0 1.0) // 誕生 let die = List.map (fun x -> x |> createPattern 1 0.0) // 過疎or過密 pattern |> List.mapi (fun i x -> i |> function | 2 -> survive x @ keep x | 3 -> survive x @ birth x | _ -> die x @ keep x) |> flat /// 初期ボード:グライダー銃 let getGliderguns () = [|[|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;1;0;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;1;1;0;0;0;0;0;0;1;1;0;0;0;0;0;0;0;0;0;0;0;0;1;1;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;1;0;0;0;1;0;0;0;0;1;1;0;0;0;0;0;0;0;0;0;0;0;0;1;1;0;0;0;0;0;0|]; [|0;0;1;1;0;0;0;0;0;0;0;0;1;0;0;0;0;0;1;0;0;0;1;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;1;1;0;0;0;0;0;0;0;0;1;0;0;0;1;0;1;1;0;0;0;0;1;0;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;1;0;0;0;0;0;1;0;0;0;0;0;0;0;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;1;0;0;0;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;1;1;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]; [|0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0;0|]|] /// 同じ長さを持つジャグ配列を二次元配列へ変換 let convert (source:int [][]) = (source.[0].GetLength(0),Array.length source) ||> fun row col -> Array2D.create row col 0 |> fun array -> seq { for i in 0..row - 1 do for j in 0..col - 1 do yield i,j } |> Seq.iter (fun (i,j) -> array.[i,j] <- source.[j].[i]) array /// ライフゲーム type LifeGame () as this = inherit Game() // ゲームタイトル, GraphicsDeviceManager, SpriteBatch let gametitle, gmanager, spriteBatch = "LifeGame", new GraphicsDeviceManager(this), lazy new SpriteBatch(this.GraphicsDevice) // 三層パーセプトロンの各ニューロンを生成 -> 入力:9 , 隠れ:17 , 出力:1 let inputs,middle,outputs = createNeuron 9 17 1 // ニューラルネットワークを構築, error状態を取得 let mutable (network:Network,error:float) = createNetwork inputs middle outputs 500 , 1.0 // SpriteFont let font = lazy this.Content.Load<SpriteFont>(@"Content\font\SpriteFont1") // セルのテクスチャ let textureCell = lazy this.Content.Load<Texture2D>(@"Content\hagure") // セルエフェクト用マスクテクスチャ let normalmapTextureCell = lazy this.Content.Load<Texture2D>(@"Content\hagure_alpha") // HLSLエフェクト let normalmapEffect = lazy this.Content.Load<Effect>(@"Content\normalmap") // セルとセルの間の間隔 let borderWidth, borderHeight = 0, 0 // セル描画の開始位置 let boardStartX, boardStartY = 0, 0 // セルのサイズ(テクスチャのサイズによりけり) let cellWidth, cellHeight = 18, 17 // ライフゲームの状態を表すボード let board = getGliderguns() |> convert // ボードのサイズ let width, height = Array2D.length1 board , Array2D.length2 board // ライフゲームの状態の更新制御 let mutable runFlg = true let mutable nowRunFlg = false let mutable previousRunFlg = false // ライフゲームの世代交代インターバル let mutable interval = 10.0 // マウスボタンのリリース状態 let mutable mouseButtonReleased = false // 訓練が終了したか否か let mutable trainingEnd = false // マウスクリック位置の取得 let getPos x y = new Vector2(float32(boardStartX + x * cellWidth + x * borderWidth), float32(boardStartY + y * cellHeight + y * borderHeight)) // セル描画 サークル動作 let moveInCircle (gameTime:GameTime) (speed:float) = let time = gameTime.TotalGameTime.TotalSeconds * speed let x = Math.Sin(time) |> float32 let y = Math.Cos(time) |> float32 new Vector2(x, y) // キー操作 let operateKeys () = let mouseState = Mouse.GetState() let keyboardState = Keyboard.GetState() if mouseState.LeftButton = ButtonState.Pressed && mouseButtonReleased && this.IsActive then // マウスボタン押下中 mouseButtonReleased <- false let mouseStateX, mouseStateY = mouseState.X |> float32, mouseState.Y |> float32 let mousePos = new Vector2(mouseStateX, mouseStateY ) for x in [0..width-1] do for y in [0..height-1] do let pos = getPos x y if pos.X < mousePos.X && pos.X + float32(cellWidth) > mousePos.X && pos.Y < mousePos.Y && pos.Y + float32(cellHeight) > mousePos.Y then // マウスでクリックされたところのセルの生死状態のトグル board.[x, y] <- if board.[x, y] = 0 then 1 else 0 else if mouseState.LeftButton <> ButtonState.Pressed then // マウスボタンをリリース mouseButtonReleased <- true // Pキーによる、PAUSE ON/OFF previousRunFlg <- nowRunFlg nowRunFlg <- keyboardState.IsKeyDown(Keys.P) if nowRunFlg && not previousRunFlg then runFlg <- not runFlg // ライフゲーム状態の更新 let updateState = let updateBoard () = let tmp = Array2D.create width height 0 for x in [0..width-1] do for y in [0..height-1] do let inputs = // x7:左上, x8;上, x9:右上, x4:左, x5:評価対象のセル, x6:右, x1:左下, x2:下, x3:右下 let x7 = if x-1 >= 0 && y-1 >= 0 && board.[x-1, y-1] = 1 then 1.0 else 0.0 let x8 = if y-1 >= 0 && board.[x, y-1] = 1 then 1.0 else 0.0 let x9 = if x+1 < width && y-1 >= 0 && board.[x+1, y-1] = 1 then 1.0 else 0.0 let x4 = if x-1 >= 0 && board.[x-1, y] = 1 then 1.0 else 0.0 let x5 = board.[x, y] |> float let x6 = if x+1 < width && board.[x+1, y] = 1 then 1.0 else 0.0 let x1 = if x-1 > 0 && y+1 < height && board.[x-1, y+1] = 1 then 1.0 else 0.0 let x2 = if y+1 < height && board.[x, y+1] = 1 then 1.0 else 0.0 let x3 = if x+1 < width && y+1 < height && board.[x+1, y+1] = 1 then 1.0 else 0.0 // ライフゲームのパターン [x7;x8;x9; x4;x5;x6; x1;x2;x3] // ニューラルネットワークで判定 let outputs = networkActivate network { Inputs=inputs; TeachingSignal = []} // パターンに対する出力を取得 let output = Convert.ToInt32(outputs.[0]) tmp.[x, y] <- output // ボードに状態を反映 for x in [0..width-1] do for y in [0..height-1] do board.[x, y] <- tmp.[x, y] let settim : double ref = ref 0.0 (fun (gameTime:GameTime) -> if runFlg then let nowMillSeconds = gameTime.TotalGameTime.TotalMilliseconds if !settim + interval < nowMillSeconds then settim := nowMillSeconds // インターバルごとに状態を更新 updateBoard()) let update = let lag = 300. let wait = ref 0. // ニューラルネットワークに訓練データを読み込み network <- loadPatterns network lifeGameTrainingData (fun gameTime -> wait := !wait + 60. if !wait > lag then wait := 0. if not trainingEnd then // 訓練データをロード if error > 0.1 then // ニューラルネットワークを訓練する let nw,err = training network network <- nw; error <- err if network.TryCount > network.RestartAfter then // 乱数の具合が悪かったり、ローカルミニマムにハマったりで訓練がなかなか終わらない場合は、最初から訓練しなおしてみる network <- initializeNetwork network else // 訓練おわりやしたー trainingEnd <- true else // ニューラルネットワークの訓練が終了したら、キー入力を受け付けたりライフゲームを開始 operateKeys () updateState gameTime) do // タイトルを設定 this.Window.Title <- gametitle // ゲームループの間隔を設定 (60FPS) this.TargetElapsedTime <- TimeSpan.FromSeconds(1.0 / 60.) // マウスカーソルを表示 this.IsMouseVisible <- true override this.Initialize() = // ゲームウィンドウのサイズを設定 gmanager.PreferredBackBufferWidth <- this.Width gmanager.PreferredBackBufferHeight <- this.Height base.Initialize () /// ウィンドウの幅 member this.Width with get () = cellWidth * width /// ウィンドウの高さ member this.Height with get () = cellHeight * height /// ライフゲームの状態を更新 override this.Update (gameTime:GameTime) = base.Update gameTime if Keyboard.GetState().IsKeyDown(Keys.Escape) then // Escが押されたらおしまい this.Exit() // ライフゲームクラスの状態を更新 update gameTime /// ライフゲームの状態を描画 override this.Draw (gameTime:GameTime) = base.Draw gameTime // テクスチャーデータのサンプリング方法をClampに設定 gmanager.GraphicsDevice.SamplerStates.[1] <- new SamplerState(AddressU = TextureAddressMode.Clamp, AddressV = TextureAddressMode.Clamp, AddressW = TextureAddressMode.Clamp) // 背景を黒で塗りつぶし gmanager.GraphicsDevice.Clear(Color.Black) // ライフゲームクラスの状態を描画 if not trainingEnd then // ニューラルネットワークの訓練が終わるまでは、訓練の進捗を描画 spriteBatch.Force().Begin() spriteBatch.Force().DrawString(font.Force (), String.Format("NeuralNework Training... Try:{0,3:##0}; Error:{1}", network.TryCount, error), Vector2(0.0f,0.0f), Color.White) spriteBatch.Force().End() else // 訓練終了後は、ライフゲームの状態を描画 for x in [0..width-1] do for y in [0..height-1] do let pos = getPos x y if board.[x, y] = 0 then // 死んでるセルは真っ黒くろ助 spriteBatch.Force().Begin() spriteBatch.Force().Draw(textureCell.Force(), pos, Color.Black) spriteBatch.Force().End() else // 生きてるセルは、セルのテクスチャを描画 // テクスチャの描画に使用するエフェクトの設定 let spinningLight = moveInCircle gameTime 5.0 let time = gameTime.TotalGameTime.TotalSeconds let tiltUpAndDown = 0.5f + float32(Math.Cos(time * 0.75)) * 0.1f let lightDirection = new Vector3(spinningLight * tiltUpAndDown / 2.0f, tiltUpAndDown / 2.0f) lightDirection.Normalize() normalmapEffect.Force().Parameters.["LightDirection"].SetValue(lightDirection) gmanager.GraphicsDevice.Textures.[1] <- normalmapTextureCell.Force() // HLSLのエフェクトを使用して、セルのテクスチャを描画 spriteBatch.Force().Begin(SpriteSortMode.Deferred, BlendState.AlphaBlend, null, null, null, normalmapEffect.Force()) spriteBatch.Force().Draw(textureCell.Force(), pos, Color.White) spriteBatch.Force().End()
ライフゲームの生死判定を学習させるための訓練データは、F#で順列(Permutation)と組み合わせ(Combination)。YOU、Listモナドしちゃいなよ。集合モナドもあるよ。で書いた、
組み合わせ(Combination)を用いて全512パターンを作成しています。
セルを表している「はぐれメタル」の描画には、無駄にHLSL(High Level Shader Language)を使用しています。
HLSL
float3 LightDirection; float3 LightColor = 2.0; float3 AmbientColor = 0.1; sampler TextureSampler : register(s0); sampler NormalSampler : register(s1); float4 main(float4 color : COLOR0, float2 texCoord : TEXCOORD0) : COLOR0 { float4 tex = tex2D(TextureSampler, texCoord); float3 normal = tex2D(NormalSampler, texCoord); float lightAmount = max(dot(normal, LightDirection), 0.2); color.rgb *= AmbientColor + lightAmount * LightColor; return tex * color; } technique Normalmap { pass Pass1 { PixelShader = compile ps_2_0 main(); } }
errorが0.1以下になるまで訓練するようにしているので、ローカルミニマムにハマってしまい、なかなか最後まで学習が完了しない。
早く収束させるには、中間層の隠れニューロンの数を調整したり訓練を甘くして学習レベルを下げるとよい。
この実装では運に左右される。ローカルミニマムに陥る問題を避ける方法はいくつかあるようだが、それはまた別のお話。
SkyDriveに、F#でニューラルネットワークなソースコード一式を置いておきます。
SkyDrive - NN.zip