uni memo

tensorflowjsのpre-trained modelを実行してみる

tensorflowjsのpre-trained modelを実行してみる

tensorflowjsには学習済みのモデルがいくつか公開されている。その中のmobilenetについてreactでデモページを作って試してみた

環境

  • node: 8.11.2
  • @tensorflow-models/mobilenet: 0.2.2
  • create-react-app: 1.1.4

reactのテンプレートにはcreate-react-appを使っている

フォルダ構成は↓のようにした

/src/components/ #component
    /assets/ #image file
    /App.js
    /index.js

ソースコード

概ねBuild software better, togetherのreadme通りにした

packageをinstallする

npm install --save @tensorflow-models/mobilenet

デモ用ページ

デモ用のページを作成

./components/ImageNetPage.js
import React, { Component } from 'react'

import * as mobilenet from '@tensorflow-models/mobilenet'

import CatImg from "./assets/cat1.jpeg"
import AirplainImg from "./assets/airplain.jpeg"

import './components/imagenetpage.css';


export default class ImageNetPage extends Component {
  constructor(props) {
    super(props)

    this.state = {
      // model loading flag読み込み中は画像をクリックできなくする
      modelLoading: true, 
      preds: []
    }
  }

  async componentDidMount() {
    this.model = await mobilenet.load()

    this.setState({
      modelLoading: false
    })
  }

  // 画像をクリックしたら予測を行い、結果をstateにセット
  onClick(e) {
    let img = e.target

    this.model.classify(img).then(preds => {
      this.setState({
        preds: preds
      })
    })
  }

  render() {
    // モデルを読み込んでいるときはdisableクラスを画像に追加
    let disable = this.state.modelLoading ? 'disable' : ''

    return (
      <div>
        <h2>predict image class name</h2>
        <h3>click image</h3>
        <img src={CatImg} className={`Cat1 ${disable}`} alt="cat" onClick={this.onClick.bind(this)} />
        <img src={AirplainImg} className={`Airplain ${disable}`} alt="airplain" onClick={this.onClick.bind(this)} />
        {
          this.state.preds.map((pred, index) => {
            let name = pred.className
            let prob = Number(pred.probability).toFixed(5)
            return (
              <div key={index}>
                <p>class: {name}</p>
                <p>probability: {prob}</p>
              </div>
            )
          })
        }
      </div>
    )
  }
}
imagenetpage.css
.disable {
  pointer-events: none;
}

ページ読み込み部分

App.js
import React, { Component } from 'react'

import ImageNetPage from './components/ImageNetPage'

import "./App.css"

class App extends Component {
  render() {
    return (
      <div>
        <ImageNetPage />
      </div>
    );
  }
}

export default App;
index.js
import React from 'react';
import ReactDOM from 'react-dom';
import App from './App';
import registerServiceWorker from './registerServiceWorker';
import './index.css';

ReactDOM.render(<App />, document.getElementById('root'));
registerServiceWorker();

以下のようなページが出来上がる

ちなみにクラス数は1000あるらしい

https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/src/imagenet_classes.ts

参考

2024, Built with Gatsby. This site uses Google Analytics.