tensorflow.js の mnist (数値の画像分類) のコードを読んで勉強した

tfjs の練習がてら、ちょっとやってみた。大学生の頃PRMLを二章ぐらいまで読んでギブアップした程度の知識なので、ほぼ無知。

js.tensorflow.org

間違ってたら教えてほしい…というか特に予習とかせず生半可な知識でやったのでたぶん間違ってる。

入力と出力

  • 32(横)x32(縦)x2(白黒) の画像が入力値(数値が書いてある画像)
  • (なんかいろいろあって) 0 から 9 いずれかの数値に分類する

そのまま入力して関係を見出すのが困難なので、なんやかんやで次元を少なくする

  • 元の画像の入力を8x8 の二次元のベクトル(Tensor)に圧縮
  • 64の入力に対し、いずれかの数値に分類する

そういう仮説が事前にある、というのを受け入れるのが時間かかった。 さすがに画像を放り込むと、なんかすごいディープラーニング様ってやつが出力当ててくれるというものではない。

畳み込み(CNN)

  • ある x, y に対し、その周辺8マスを含む 3x3 の領域で白黒を判定する
  • 2x2 の領域を見て、その中間値で1の値に圧縮する(Pooling)
  • これを二回繰り返す(32*32/4/4=64)

というモデルを表すコードがこれ

function createConvModel() {
  // Create a sequential neural network model. tf.sequential provides an API
  // for creating "stacked" models where the output from one layer is used as
  // the input to the next layer.
  const model = tf.sequential();

  // The first layer of the convolutional neural network plays a dual role:
  // it is both the input layer of the neural network and a layer that performs
  // the first convolution operation on the input. It receives the 28x28 pixels
  // black and white images. This input layer uses 16 filters with a kernel size
  // of 5 pixels each. It uses a simple RELU activation function which pretty
  // much just looks like this: __/
  model.add(
    tf.layers.conv2d({
      inputShape: [IMAGE_H, IMAGE_W, 1],
      kernelSize: 3,
      filters: 16,
      activation: "relu"
    })
  );

  // After the first layer we include a MaxPooling layer. This acts as a sort of
  // downsampling using max values in a region instead of averaging.
  // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Our third layer is another convolution, this time with 32 filters.
  model.add(
    tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: "relu" })
  );

  // Max pooling again.
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Add another conv2d layer.
  model.add(
    tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: "relu" })
  );

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten({}));

  model.add(tf.layers.dense({ units: 64, activation: "relu" }));

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually
  // represent numbers, but it's the same idea if you had classes that
  // represented other entities like dogs and cats (two output classes: 0, 1).
  // We use the softmax function as the activation for the output layer as it
  // creates a probability distribution over our 10 classes so their output
  // values sum to 1.
  model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

  return model;
}

学習

画像とその数値の正解データセットが用意されてるので、それを一つずつ入力する。どの数値っぽいかという出力が出てくる。 出力に対し、正解の教師データと突き合わせて categoricalCrossentropy (ググると Multi-class logloss という名前のが知られている)で、正解との距離を対数で取ったものの和を誤差として用い、 LEARNING_RATE = 0.01 でバックプロパゲーション(誤差伝搬)させる。これでニューラルネット間の係数が修正されていく。

  // optimizer 定義
  model.compile({
    optimizer,
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"]
  });

  // 訓練
  await model.fit(trainData.xs, trainData.labels, {
    batchSize,
    validationSplit,
    epochs: trainEpochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        trainBatchCount++;
        ui.logStatus(
          `Training... (` +
            `${((trainBatchCount / totalNumBatches) * 100).toFixed(1)}%` +
            ` complete). To stop training, refresh or close page.`
        );
        ui.plotLoss(trainBatchCount, logs.loss, "train");
        ui.plotAccuracy(trainBatchCount, logs.acc, "train");
        await tf.nextFrame();
      },
      onEpochEnd: async (epoch, logs) => {
        valAcc = logs.val_acc;
        ui.plotLoss(trainBatchCount, logs.val_loss, "validation");
        ui.plotAccuracy(trainBatchCount, logs.val_acc, "validation");
        await tf.nextFrame();
      }
    }
  });

  // 訓練されたモデルを使う
  const testResult = model.evaluate(testData.xs, testData.labels);

感想

CNN って想像していたニューラルネットっぽくないというか、これ単にヒューリスティックだった下処理に名前が付いてるやつでは?という気持ちになった。

なんか頑張って読むことは出来たが、自分で作れといわれてたら無理なので、素振りする必要がありそう。