ようこそ。睡眠不足なプログラマのチラ裏です。

メモ化を抽象的に考えて一般化する。これぞジェネリックプログラミングの神髄!なんつってー

前回のエントリーで書いた関数の「メモ化」について抽象的に考えて、
ジェネリックプログラミングをしてメモ化を一般化してみましょう。
ググッたところ、1つの引数をとる関数のメモ化関数は多くの人が書いていますが、
2つ以上の引数をとる関数のメモ化関数については、
見あたりませんでしたので、気が向いたのでちょっと一肌脱いでみました。


関数のメモ化の一般化

いきなりですが、以下コードです。

using System;
using System.Linq;
using System.Collections.Generic;

namespace ConsoleApplication1
{
    /// <summary>
    /// 任意の関数についてのメモ化を提供します。
    /// 
    /// ただし、任意の関数内にクロージャを含まない前提で利用してください。
    /// 関数内にクロージャを含む場合、クロージャ内の環境変化は無視されることに注意してください。
    /// </summary>
    public static class Memoization
    {
        /// <summary>
        /// Func{TResult}について関数をメモ化します。
        /// </summary>
        /// <typeparam name="TArgs"></typeparam>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="f"></param>
        /// <returns></returns>
        public static Func<TResult> Memoize<TResult>(this Func<TResult> f)
        {
            var value = default(TResult);
            bool hasValue = false;
            return () =>
            {
                if (!hasValue)
                {
                    hasValue = true;
                    value = f();
                }
                return value;
            };
        }

        /// <summary>
        /// Func{TArgs, TResult}について関数をメモ化します。
        /// </summary>
        /// <typeparam name="TArgs"></typeparam>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="f"></param>
        /// <returns></returns>
        public static Func<TArgs, TResult> Memoize<TArgs, TResult>(this Func<TArgs, TResult> f)
        {
            var dic = new Dictionary<TArgs, TResult>();
            return x =>
            {
                if (dic.ContainsKey(x)) return dic[x];
                return dic[x] = f(x);
            };
        }

        /// <summary>
        /// Func{TArgs1, TArgs2, TResult}について関数をメモ化します。
        /// </summary>
        /// <typeparam name="TArgs1"></typeparam>
        /// <typeparam name="TArgs2"></typeparam>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="f"></param>
        /// <returns></returns>
        public static Func<TArgs1, TArgs2, TResult> Memoize<TArgs1, TArgs2, TResult>(this Func<TArgs1, TArgs2, TResult> f)
        {
            var key = new { a = default(TArgs1), b = default(TArgs2) };
            var dynamicList = DynamicCreateGenericType(typeof(Dictionary<,>), new Type[] { key.GetType(), typeof(TResult) });
            var add = dynamicList.GetType().GetMethod("Add");
            var contains = dynamicList.GetType().GetMethod("ContainsKey");
            var items = dynamicList.GetType().GetMethod("get_Item");

            return (x, y) =>
            {
                key = new { a = x, b = y };
                var exist = (bool)contains.Invoke(dynamicList, new object[] { key });
                if (exist) return (TResult)items.Invoke(dynamicList, new object[] { key });
                add.Invoke(dynamicList, new object[] { key, f(key.a, key.b) });
                return (TResult)items.Invoke(dynamicList, new object[] { key });
            };
        }

        /// <summary>
        /// Func{TArgs1, TArgs2, TArgs3, TResult}について関数をメモ化します。
        /// </summary>
        /// <typeparam name="TArgs1"></typeparam>
        /// <typeparam name="TArgs2"></typeparam>
        /// <typeparam name="TArgs3"></typeparam>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="f"></param>
        /// <returns></returns>
        public static Func<TArgs1, TArgs2, TArgs3, TResult> Memoize<TArgs1, TArgs2, TArgs3, TResult>(this Func<TArgs1, TArgs2, TArgs3, TResult> f)
        {
            var key = new { a = default(TArgs1), b = default(TArgs2), c = default(TArgs3) };
            var dynamicList = DynamicCreateGenericType(typeof(Dictionary<,>), new Type[] { key.GetType(), typeof(TResult) });
            var add = dynamicList.GetType().GetMethod("Add");
            var contains = dynamicList.GetType().GetMethod("ContainsKey");
            var items = dynamicList.GetType().GetMethod("get_Item");

            return (x, y, z) =>
            {
                key = new { a = x, b = y, c = z };
                var exist = (bool)contains.Invoke(dynamicList, new object[] { key });
                if (exist) return (TResult)items.Invoke(dynamicList, new object[] { key });
                add.Invoke(dynamicList, new object[] { key, f(key.a, key.b, key.c) });
                return (TResult)items.Invoke(dynamicList, new object[] { key });
            };
        }

        /// <summary>
        /// 任意のジェネリック型を動的に生成します。
        /// </summary>
        /// <typeparam name="T">
        /// ジェネリック型を構築する元になるジェネリック型定義
        /// 例えば、List{}やDictionary{,}など
        /// </typeparam>
        /// <param name="genericArgMetadata">ジェネリック型に必要な引数メタデータ</param>
        /// <returns>任意のジェネリック型のインスタンス</returns>
        public static object DynamicCreateGenericType(Type genericType, Type[] genericArgMetadata)
        {
            if (!genericType.IsGenericType || genericType != genericType.GetGenericTypeDefinition())
                throw new ArgumentException("指定されたgenericTypeは、ジェネリック型を構築する元になるジェネリック型定義(GenericTypeDefinition)ではありません。");
            var genericTypeArgumentCount = genericType.GetGenericArguments().Count();
            if (genericTypeArgumentCount != genericArgMetadata.Count())
                throw new ArgumentOutOfRangeException("生成するジェネリック型の引数の数と、Type[]メタデータの数が異なります。");

            Type gtd = genericType.GetGenericTypeDefinition();
            Type dgtype = gtd.MakeGenericType(genericArgMetadata);
            var dg = Activator.CreateInstance(dgtype);
            return dg;
        }
    }
}

見所は、なんといっても2引数および3引数をとるメモ化関数でしょう。
以前から愛用している「任意のジェネリック型を動的に生成」するスニペットが大活躍です。
匿名クラスを用いているところがミソですね。(C#4.0であればTupleで代用できる)
その匿名クラスの型を指定したジェネリックコレクションを
動的に生成しているところに、男気を感じて頂けるとありがたいです。
これぞジェネリックプログラミングの神髄!なんつってー


お試し

メモ化関数の動作を試してみましょう。

using System;
using System.Linq;
using System.Diagnostics;
using System.Threading;

namespace ConsoleApplication1
{
    class Program
    {
        static void Main(string[] args)
        {
            Console.WriteLine("Test1-----------------------------");
            {
                Func<int, int> fib = null;
                fib = n => n > 1 ? fib(n - 1) + fib(n - 2) : n;
                var sw = new Stopwatch();
                sw.Start();
                var r1 = fib(30);
                sw.Stop();
                Console.WriteLine("{0}:{1}",r1,sw.ElapsedTicks);

                //メモ化
                fib = fib.Memoize();

                sw.Restart();
                var r2 = fib(30);
                sw.Stop();
                Console.WriteLine("{0}:{1}", r2, sw.ElapsedTicks);
            }

            Console.WriteLine("Test2-----------------------------");
            {
                Func<int> memofunc = () => Mossari0();
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc());

                //メモ化
                memofunc = memofunc.Memoize();
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc());
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc());
            }

            Console.WriteLine("Test3-----------------------------");
            {
                Func<int, int> memofunc = x => Mossari1(x);
                memofunc = memofunc.Memoize();
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc(i));
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc(i));
            }

            Console.WriteLine("Test4-----------------------------");
            {
                Func<int, string, string> memofunc2 = (x, y) => Mossari2(x, y);
                memofunc2 = memofunc2.Memoize();
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc2(i, "メセタ"));
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc2(i, "メセタ"));
            }

            Console.WriteLine("Test5-----------------------------");
            {
                Func<int, int, string, string> memofunc3 = (x, y, z) => Mossari3(x, y, z);
                memofunc3 = memofunc3.Memoize();
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc3(i, i, "へぇ"));
                foreach (var i in Enumerable.Range(0, 5))
                    Console.WriteLine(memofunc3(i, i, "へぇ"));
            }

            Console.ReadKey();
        }

        private static int i = 0;
        static int Mossari0()
        {
            Thread.Sleep(1000);
            //メモ化すると、このクロージャの環境はそれ以降は評価されない
            Func<int> f = () => ++i;
            return f();
        }

        static int Mossari1(int s)
        {
            Thread.Sleep(1000);
            return s + 1;
        }

        static string Mossari2(int i, string s)
        {
            Thread.Sleep(1000);
            return ++i + s;
        }

        static string Mossari3(int i, int j, string s)
        {
            Thread.Sleep(1000);
            return (i * j) + s;
        }
    }
}

いい感じです。
フィボナッチな関数もかなり高速化しますね。


あと、コメントにも書いていますが、
このメモ化関数を用いて関数をメモ化すると、メモ化対象関数内のクロージャの環境は、
メモ化以降には評価されないという点には注意が必要です。その点にさえ注意すれば、
このようにメモ化関数を一般化することで、使いたいときにささっとメモ化ができるのは、
なかなか嬉しいんじゃないでしょうか。何かの参考になれば幸いです。