今回の記事ではTensorFlowLiteをSwiftで使用する、「Object Detection API」の中ファイルである「ModelDataHandler.swift」の解説記事です。初心者でも理解できるように詳しく解説していきますので是非参考にしてみてください。
・object detection API がどの様に動いているか?
・object detection API の入出力について
・object detection API のModelDataHandler.swiftの動きについて ←今回
・object detection API のCVPixelBufferExtension.swiftの動きについて
・object detection API のViewController.swiftの動きについて
該当のgithubのレポジトリはこちらです。
https://github.com/tensorflow/examples.git
ではメインの記事に進んでいきます。
Object Detection APIのモジュール
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import CoreImage
import TensorFlowLite
import UIKit
import Accelerate
まずは各モジュールをインストールしています。
TensorFlowLiteに関しては初期状態のXcodeプロジェクトには入っていないのでpodfileのカレントでpod installをしてモジュールをこのプロジェクトに足してください。
Object Detection APIの変数定義
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
/// Stores results for a particular frame that was successfully run through the `Interpreter`.
struct Result {
let inferenceTime: Double
let inferences: [Inference]
}
/// Stores one formatted inference.
struct Inference {
let confidence: Float
let className: String
let rect: CGRect
let displayColor: UIColor
}
/// Information about a model file or labels file.
typealias FileInfo = (name: String, extension: String)
/// Information about the MobileNet SSD model.
enum MobileNetSSD {
static let modelInfo: FileInfo = (name: "detect", extension: "tflite")
static let labelsInfo: FileInfo = (name: "labelmap", extension: "txt")
}
次のコードでは変数の宣言を行なっています。ここに書かれているものはclass外のものでそれらは共通して他のファイルと連携しているものであり最重要です。
structやenumの形で扱いやすくしてくれていますが初心者からすると難しいです。下記記事でenumとstructの使い方に関して簡単に説明しているので是非参考にしてみてください。
「Swiftでstructを使う方法」
「Swiftでenumを使う方法」
ついでにクラスの使い方も理解したい方はリンク先の記事にリンクがあるのでそちらから参照ください。
class外のstruct Result,Inferenceは最終的にラベル(タグ)と出力結果を結びつけるものと学習モデルから出力されたものになっておりこのプロジェクト内で最も重要です。
enum MobileNetSSDには今回使用するモデルの名前を記載しています。
もしもオリジナルのモデルを使う場合はここの名称の変更と、Modelフォルダの中に学習モデルを入れる必要があります。
/// This class handles all data preprocessing and makes calls to run inference on a given frame
/// by invoking the `Interpreter`. It then formats the inferences obtained and returns the top N
/// results for a successful inference.
class ModelDataHandler: NSObject {
// MARK: - Internal Properties
/// The current thread count used by the TensorFlow Lite Interpreter.
let threadCount: Int
let threadCountLimit = 10
let threshold: Float = 0.5
// MARK: Model parameters
let batchSize = 1
let inputChannels = 3
let inputWidth = 300
let inputHeight = 300
// image mean and std for floating model, should be consistent with parameters used in model training
let imageMean: Float = 127.5
let imageStd: Float = 127.5
// MARK: Private properties
private var labels: [String] = []
/// TensorFlow Lite `Interpreter` object for performing inference on a given model.
private var interpreter: Interpreter
private let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2)
private let rgbPixelChannels = 3
private let colorStrideValue = 10
private let colors = [
UIColor.red,
UIColor(displayP3Red: 90.0/255.0, green: 200.0/255.0, blue: 250.0/255.0, alpha: 1.0),
UIColor.green,
UIColor.orange,
UIColor.blue,
UIColor.purple,
UIColor.magenta,
UIColor.yellow,
UIColor.cyan,
UIColor.brown
]
ここまでが変数の宣言になっております。こちらは先ほどの変数に対してclass内にて宣言されています。
UI画面の設定がこの辺りでほぼ行われています。
inputWidthとHeightがpixcelbufferからCIImageに変換された画像のSizeになっており、学習モデルのインプット画像のサイズと合わせる必要があります。
label、interpreterは最終的な物体検知後の出力用配列と物体検知tensorflowliteとモデルを連携させるための関数用の変数と考えれば良いです。
interpreterについては奥深い気がするので別記事でも紹介しています。
Object Detection APIの初期化
// MARK: - Initialization
/// A failable initializer for `ModelDataHandler`. A new instance is created if the model and
/// labels files are successfully loaded from the app's main bundle. Default `threadCount` is 1.
init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int = 1) {
let modelFilename = modelFileInfo.name
// Construct the path to the model file.
guard let modelPath = Bundle.main.path(
forResource: modelFilename,
ofType: modelFileInfo.extension
) else {
print("Failed to load the model file with name: \(modelFilename).")
return nil
}
// Specify the options for the `Interpreter`.
self.threadCount = threadCount
var options = Interpreter.Options()
options.threadCount = threadCount
do {
// Create the `Interpreter`.
interpreter = try Interpreter(modelPath: modelPath, options: options)
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
} catch let error {
print("Failed to create the interpreter with error: \(error.localizedDescription)")
return nil
}
super.init()
// Load the classes listed in the labels file.
loadLabels(fileInfo: labelsFileInfo)
}
ここまで出てきた変数の初期化を行なっています。
またinterpreterに使用するモデルのセット、最後にlabelfileInfoにてラベルマップの情報を読み込む関数loadLabelsを使用しています。
以下からは関数になっております。
特筆していない関数は全てModelDataHandlerクラスの関数です。
/// This class handles all data preprocessing and makes calls to run inference on a given frame
/// through the `Interpeter`. It then formats the inferences obtained and returns the top N
/// results for a successful inference.
func runModel(onFrame pixelBuffer: CVPixelBuffer) -> Result? {
let imageWidth = CVPixelBufferGetWidth(pixelBuffer)
let imageHeight = CVPixelBufferGetHeight(pixelBuffer)
let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
sourcePixelFormat == kCVPixelFormatType_32BGRA ||
sourcePixelFormat == kCVPixelFormatType_32RGBA)
let imageChannels = 4
assert(imageChannels >= inputChannels)
// Crops the image to the biggest square in the center and scales it down to model dimensions.
let scaledSize = CGSize(width: inputWidth, height: inputHeight)
guard let scaledPixelBuffer = pixelBuffer.resized(to: scaledSize) else {
return nil
}
let interval: TimeInterval
let outputBoundingBox: Tensor
let outputClasses: Tensor
let outputScores: Tensor
let outputCount: Tensor
do {
let inputTensor = try interpreter.input(at: 0)
// Remove the alpha component from the image buffer to get the RGB data.
guard let rgbData = rgbDataFromBuffer(
scaledPixelBuffer,
byteCount: batchSize * inputWidth * inputHeight * inputChannels,
isModelQuantized: inputTensor.dataType == .uInt8
) else {
print("Failed to convert the image buffer to RGB data.")
return nil
}
// Copy the RGB data to the input `Tensor`.
try interpreter.copy(rgbData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
let startDate = Date()
try interpreter.invoke()
interval = Date().timeIntervalSince(startDate) * 1000
outputBoundingBox = try interpreter.output(at: 0)
outputClasses = try interpreter.output(at: 1)
outputScores = try interpreter.output(at: 2)
outputCount = try interpreter.output(at: 3)
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
return nil
}
// Formats the results
let resultArray = formatResults(
boundingBox: [Float](unsafeData: outputBoundingBox.data) ?? [],
outputClasses: [Float](unsafeData: outputClasses.data) ?? [],
outputScores: [Float](unsafeData: outputScores.data) ?? [],
outputCount: Int(([Float](unsafeData: outputCount.data) ?? [0])[0]),
width: CGFloat(imageWidth),
height: CGFloat(imageHeight)
)
// Returns the inference time and inferences
let result = Result(inferenceTime: interval, inferences: resultArray)
return result
}
いきなり出てきましたが最も大切な関数です。
この「runmodel」関数が物体検知を行なっているモデルを動かしている関数です。
上から順に流れを追って行きます。
まず入力はonFrame pixelBuffer: CVPixelBuffer、要は背面カメラから入ってくるCVPixelBuffeとなております。
出力の型はResult?です。(?にすることで型の制約が緩くなりnilが戻り値になってもOKになります。これでguard文が使えるようになります。)Result?は初めの方に宣言していたグローバル変数のstruct Resultのことですので要素もinferenceTime: Doubleとinferences: [Inference]が出力されることがわかります。
このinferencesはさらにstruct Inferenceを持っていますのでconfidence: Float、className: String、rect: CGRect、displayColor: UIColorが格納されます。
始めに入力のbufferから画像のImageサイズ関連の情報を変数に格納し、そのサイズにbufferをリサイズしております。その時に使用しているresize関数はextensionファイルにて記載されている、pixelbufferの拡張機能ですのでextensionコード解説時に再度説明します。
次に出力先のTensorの変数を宣言します。outbundingboxが物体検知の領域を表している枠の座標のことです。他は英語の通りです。
そしてそのしたのdo catch文で推論を行なっております。
流れとしてはinterpreterにdataを入れるためにbufferをrgbdataに変換し、interpreterにセット、invoke()にてinterpreterに設定した学習モデルを使って設定したデータの物体検知を行っています。そしてinterpreter.outputにて出力された
Object Detection API出力を管理
/// Filters out all the results with confidence score < threshold and returns the top N results
/// sorted in descending order.
func formatResults(boundingBox: [Float], outputClasses: [Float], outputScores: [Float], outputCount: Int, width: CGFloat, height: CGFloat) -> [Inference]{
var resultsArray: [Inference] = []
if (outputCount == 0) {
return resultsArray
}
for i in 0...outputCount - 1 {
let score = outputScores[i]
// Filters results with confidence < threshold.
guard score >= threshold else {
continue
}
// Gets the output class names for detected classes from labels list.
let outputClassIndex = Int(outputClasses[i])
let outputClass = labels[outputClassIndex + 1]
var rect: CGRect = CGRect.zero
// Translates the detected bounding box to CGRect.
rect.origin.y = CGFloat(boundingBox[4*i])
rect.origin.x = CGFloat(boundingBox[4*i+1])
rect.size.height = CGFloat(boundingBox[4*i+2]) - rect.origin.y
rect.size.width = CGFloat(boundingBox[4*i+3]) - rect.origin.x
// The detected corners are for model dimensions. So we scale the rect with respect to the
// actual image dimensions.
let newRect = rect.applying(CGAffineTransform(scaleX: width, y: height))
// Gets the color assigned for the class
let colorToAssign = colorForClass(withIndex: outputClassIndex + 1)
let inference = Inference(confidence: score,
className: outputClass,
rect: newRect,
displayColor: colorToAssign)
resultsArray.append(inference)
}
// Sort results in descending order of confidence.
resultsArray.sort { (first, second) -> Bool in
return first.confidence > second.confidence
}
return resultsArray
}
出力されたResultをソートして出力を管理しています。ここで確率で切り捨てなどを行なっています。
Object Detection APIモデル読み込み
/// Loads the labels from the labels file and stores them in the `labels` property.
private func loadLabels(fileInfo: FileInfo) {
let filename = fileInfo.name
let fileExtension = fileInfo.extension
guard let fileURL = Bundle.main.url(forResource: filename, withExtension: fileExtension) else {
fatalError("Labels file not found in bundle. Please add a labels file with name " +
"\(filename).\(fileExtension) and try again.")
}
do {
let contents = try String(contentsOf: fileURL, encoding: .utf8)
labels = contents.components(separatedBy: .newlines)
} catch {
fatalError("Labels file named \(filename).\(fileExtension) cannot be read. Please add a " +
"valid labels file and try again.")
}
}
/// Returns the RGB data representation of the given image buffer with the specified `byteCount`.
///
/// - Parameters
/// - buffer: The BGRA pixel buffer to convert to RGB data.
/// - byteCount: The expected byte count for the RGB data calculated using the values that the
/// model was trained on: `batchSize * imageWidth * imageHeight * componentsCount`.
/// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than
/// floating point values).
/// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be
/// converted.
private func rgbDataFromBuffer(
_ buffer: CVPixelBuffer,
byteCount: Int,
isModelQuantized: Bool
) -> Data? {
CVPixelBufferLockBaseAddress(buffer, .readOnly)
defer {
CVPixelBufferUnlockBaseAddress(buffer, .readOnly)
}
guard let sourceData = CVPixelBufferGetBaseAddress(buffer) else {
return nil
}
let width = CVPixelBufferGetWidth(buffer)
let height = CVPixelBufferGetHeight(buffer)
let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(buffer)
let destinationChannelCount = 3
let destinationBytesPerRow = destinationChannelCount * width
var sourceBuffer = vImage_Buffer(data: sourceData,
height: vImagePixelCount(height),
width: vImagePixelCount(width),
rowBytes: sourceBytesPerRow)
guard let destinationData = malloc(height * destinationBytesPerRow) else {
print("Error: out of memory")
return nil
}
defer {
free(destinationData)
}
var destinationBuffer = vImage_Buffer(data: destinationData,
height: vImagePixelCount(height),
width: vImagePixelCount(width),
rowBytes: destinationBytesPerRow)
if (CVPixelBufferGetPixelFormatType(buffer) == kCVPixelFormatType_32BGRA){
vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
} else if (CVPixelBufferGetPixelFormatType(buffer) == kCVPixelFormatType_32ARGB) {
vImageConvert_ARGB8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
}
let byteData = Data(bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height)
if isModelQuantized {
return byteData
}
// Not quantized, convert to floats
let bytes = Array<UInt8>(unsafeData: byteData)!
var floats = [Float]()
for i in 0..<bytes.count {
floats.append((Float(bytes[i]) - imageMean) / imageStd)
}
return Data(copyingBufferOf: floats)
}
モデル読み込みように必要な関数を定義しています
Object Detection API枠色の調整
/// This assigns color for a particular class.
private func colorForClass(withIndex index: Int) -> UIColor {
// We have a set of colors and the depending upon a stride, it assigns variations to of the base
// colors to each object based on its index.
let baseColor = colors[index % colors.count]
var colorToAssign = baseColor
let percentage = CGFloat((colorStrideValue / 2 - index / colors.count) * colorStrideValue)
if let modifiedColor = baseColor.getModified(byPercentage: percentage) {
colorToAssign = modifiedColor
}
return colorToAssign
}
}
枠色を調節しています。
Object Detection API型のextension
// MARK: - Extensions
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
使用する型の拡張を多なっております。モデル使用時などに必要な処理です。
以上がObject Detection APIの「modeldatahandler.swift」の働きになります。非常に長く複雑です。読み解くのが一苦労です。
では今回の今回の記事は以上です。他にも多数のSwift関連や機械学習関連の記事を記載しています。是非参考にしてみてください。
コメント
[…] 「TensorflowLiteのModeldatahandlerの説明。」 […]