今回は取得したデータを使用して学習するところになります。
学習する際の条件は以下のようになります。

隠れ層数

隠れ層1 =16
隠れ層2 =24

読み込みデータ量

各Dataのデータを5,000ずつ、合計16,0000件以上を目標

流れ

学習させるためのPythonコードはTensorFlowに同梱されているTutorial"MNIST"を改造して作成しています。
変更点としては以下のようになります。
  1. Unlimited Handのセンサーから取得したデータを読み込む
  2. 学習結果をProtocol Bufferデータ(PBファイル)として出力
上記2点についてはコードを参照しながら説明していきたいと思います。
fully_connected_feed.py - run_training()
                : 省略
        data_files = []
        if FLAGS.random_learning:
            max_read_step, out_file_name = create_random_data_file()
            data_files = [out_file_name]
        else:
            data_files = glob.glob(FLAGS.input_data_dir + "/sensor_data_*")

        total_read_step = len(data_files) * FLAGS.offset

        for data_file in data_files:
            print('%s: ' % data_file)
            start_offset_step = offset_step = FLAGS.offset

            while True:
                # read data_sets from CVS
                data_sets = read_sensor_data_sets(data_file, offset_step=offset_step, read_step=read_step)

                if data_sets != None:
                    # Start the training loop.
                    start_time = time.time()

FLAGS.random_learningですが、各教師データ毎の学習データを順番に学習させるか、ランダムな順番に学習させるかの制御になります。
FLAGS.random_learning
無効
FLAGS.random_learning
有効
教師データ00000の学習データ(全データ)実施
教師データ00001の学習データ(全データ)実施
        :32パターンの教師データに対する
          学習データを順番に学習
教師データ11110の学習データ(全データ)実施
教師データ11111の学習データ(全データ)実施
=> 終了
教師データ10101の学習データ(1データ)実施
教師データ00001の学習データ(1データ)実施
        :32パターンの教師データに対する
          学習データをランダムな順番で学習
教師データ11110の学習データ(1データ)実施
=> 32パターン全て1つずつ学習したら、再度32パターンの教師データに対する学習データをランダムな順番で学習

学習効果は結果的には変わらないようです。FLAGS.random_learningを有効にした方がTensorBoardでlossの値の遷移グラフがきれいになりますが、推論結果には大きな影響は見られませんでした。
FLAGS.random_learning
無効
FLAGS.random_learning
有効

read_sensor_data_setsで今回読み込みの対象となっているファイルからセンサーデータを読み込み、MNISTでも使用しているDataSet型で返却しています。
fully_connected_feed.py - run_training()
        graph_io.write_graph(sess.graph, FLAGS.saved_data_dir, "saved_data.pb", as_text=False)
        input_binary = True

        input_graph_path = os.path.join(FLAGS.saved_data_dir, "saved_data.pb")
        input_saver = ""
        output_node_names = "eval_correct"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(FLAGS.saved_data_dir, "saved_data_out.pb")
        clear_devices = False

        freeze_graph.freeze_graph(input_graph_path, input_saver,
                                  input_binary, checkpoint, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_path, clear_devices, "", "")

最初はtf.saved_modelで保存しようと試みたのですが、tf.saved_modelで保存したProtocol Bufferファイルでは、Androidで読み込むときに
  1. Variableが初期化される前に使用された
  2. 不明なOpが存在する
といったエラーが発生しました。
1つ目のエラーはtf.saved_modelで保存されている情報には設定されている値が書き出されないために発生するとのことでした(詳細はこちらを参照)
2つ目のエラーはAndroid側で使用しているTensorFlow用のライブラリがinferenceのみのサブセットとなっているため、training時に使用したOpが存在しないために発生しているようです。
そのため、上のStackOverflowを頼りfreeze_graphを使用しています。
この際の注意点としては「output_node_names = "eval_correct"」になります。
これはProtocol Bufferに出力するOpを明示します。このOpはAndroid側でコールして推論するときに使用したいOpになります。

出力したProtocol BufferデータをAPKに組み込む

組み込む方法はある程度自由ですが、今回はassetsとしてAPKに組み込んでいます。
また推論を行うためのTensorFlowライブラリですが、TensorFlowが公式にCIにてビルドしたものをこちらで配布しているので、それを使用しています。
Protocol Bufferファイルへのパス(※)を以下のmlPbFileに設定することで、学習結果を使用した推論が使用できるようになります。
※ 設定するパスは以下のようにする必要があります。
Assets内のファイルの場合:file:///android_asset/[Assets内のパス]
外部ストレージ上のファイルの場合:fileスキームなしのパス(例:/storage/emulated/0/XXXXX)
fully_connected_feed.py - run_training()
mInterface = new TensorFlowInferenceInterface(context.getAssets(), mlPbFile);

そしてfeedメソッドでPlaceHolder"sensor_values_placeholder:0"にセンサーから取得したデータを学習時と同じように加工したものを設定、"labels_placeholder:0"には比較する指の状態を設定します。
その後、runメソッドを実行することで推論が実行され、fetchメソッドでrunの結果を得られます。この場合は、"labels_placeholder:0"に設定したデータと一致しているか否かがfetchにて取得できるので、一致したら判定を終了させます。
UhGestureDetector2.java - UhGestureDetector2
for (int i = 0, size = (int) Math.pow(2, 5); i < size; i++) {
    try {
        // TensorFlow_NightlyBuild
        mInterface.feed(INPUT_SENSOR_NAME, sensorValueArray, 1, sensorValueArray.length);
        mInterface.feed(INPUT_LABEL_NAME, new int[]{i}, 1);
        mInterface.run(new String[]{OUTPUT_GESTURE_NAME});
        mInterface.fetch(OUTPUT_GESTURE_NAME, inferenceResult);
        if (inferenceResult[0] != 0) {
            value = i;
            break;
        }
    } catch (Exception e) {
        successFetch = false;
        LogUtil.exception(TAG, e);
    }
}

コードを参照したい方は以下のものを参照してください。
・学習用(Python):https://github.com/eq-inc/eq-tensorflow-learn_uh_sensor_values/tree/v0.0.7
・推論用(Android: Java):https://github.com/thcomp/Android_UnlimitedHandAccessHelper/tree/v0.0.9
 unlimitedhand/src/main/java/jp/co/thcomp/unlimitedhand/UhGestureDetector2.java

コメントの投稿