FC2ブログ

頭と尻尾はくれてやる!

パソコンおやじのiPhoneアプリ・サイト作成・運営日記


MPSImageオブジェクトから予測結果のfloat値を得る

Metal Performance Shadersで機械学習定番である手書き数字認識のMNISTをやった場合、最終的には10個のアウトプットの数値(それぞれの数字の確率)が必要になる。
MPSだとMPSImageオブジェクトにアウトプットが入るんだがそこから値を得るのに(なぜか今頃)はまったのでメモ。

MNIST用のNNを構成した場合、最終的にはたいていsoftmax関数の出力を得ると思う。
{
    MPSCNNSoftMax *softmax;


    softmax = [[MPSCNNSoftMax alloc] initWithDevice:device];

    MPSImageDescriptor *fc2id = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 width:1 height:1 featureChannels:10];//—(1)

    finalImage = [[MPSImage alloc] initWithDevice:device imageDescriptor:fc2id];


    [softmax encodeToCommandBuffer:commandBuffer sourceImage:fc2Image destinationImage:finalImage];
}
↑出力層あたりの関連部分だけ書いてるけど、(1)に示すように最終的な出力はwidth1 x height 1 x channel10 で、その各値は32bit(※)の浮動小数点数のMPSImageオブジェクトに入ることになる。

これをObjective-Cのfloat値として読むのがこんな感じ↓
{
    float *outputs =  calloc(12, sizeof(float));

    MTLRegion region = MTLRegionMake3D(0, 0, 0, 1, 1, 1);
    
    for (NSUInteger ite=0;ite<3;ite++) {
        [finalImage.texture getBytes:&(outputs[4*ite])
                         bytesPerRow:sizeof(float)*4
                       bytesPerImage:sizeof(float)*4
                          fromRegion:region
                         mipmapLevel:0
                               slice:ite];
    }
}
これで各値はoutputs[ite]でアクセスできる。4つのfloat値を3回に分けて取得してる。なので一応10個ではなく12個のメモリを確保(calloc)してる。


なお、最初にはまっていたのは下のようなコードで値を得ようとしていたから。
{
                [finalImage.texture getBytes:&outputs[0]
                                       bytesPerRow:sizeof(float)*4
                                        fromRegion:filnalImageRegion
                                       mipmapLevel:0];
}
これだと4つしか値が取得できない。
これ強化学習の倒立振子で使っていた。倒立振子だと出力が二つ(どちらの方向に回転するとよいか?)だったのでたまたま問題なかった。


(※)
Objective-Cで使うfloatは32bitなので問題ないけど、(俺の勘違いでなければ)MPSがリリースされた当初は出力に16bitの浮動小数点数しか使えず、その出力結果をObjective-Cで32bitのfloatで得るのがすごく面倒だった。

<< Kerasの学習後パラメータを保存する(CNNの場合)  TopPage  SceneKitでオブジェクト表面に背景画像を表示する >>

コメント


管理者にだけ表示を許可する
 

トラックバック

トラックバックURL
http://ringsbell.blog117.fc2.com/tb.php/1210-32c4e13a




Copyright ©頭と尻尾はくれてやる!. Powered by FC2 Blog. Template by eriraha.

FC2Ad