tensorflowjsには学習済みのモデルがいくつか公開されている。その中のmobilenetについてreactでデモページを作って試してみた
環境
- node: 8.11.2
- @tensorflow-models/mobilenet: 0.2.2
- create-react-app: 1.1.4
reactのテンプレートにはcreate-react-appを使っている
フォルダ構成は↓のようにした
/src/components/ #component
/assets/ #image file
/App.js
/index.js
ソースコード
概ねBuild software better, togetherのreadme通りにした
packageをinstallする
npm install --save @tensorflow-models/mobilenet
デモ用ページ
デモ用のページを作成
./components/ImageNetPage.js
import React, { Component } from 'react'
import * as mobilenet from '@tensorflow-models/mobilenet'
import CatImg from "./assets/cat1.jpeg"
import AirplainImg from "./assets/airplain.jpeg"
import './components/imagenetpage.css';
export default class ImageNetPage extends Component {
constructor(props) {
super(props)
this.state = {
// model loading flag読み込み中は画像をクリックできなくする
modelLoading: true,
preds: []
}
}
async componentDidMount() {
this.model = await mobilenet.load()
this.setState({
modelLoading: false
})
}
// 画像をクリックしたら予測を行い、結果をstateにセット
onClick(e) {
let img = e.target
this.model.classify(img).then(preds => {
this.setState({
preds: preds
})
})
}
render() {
// モデルを読み込んでいるときはdisableクラスを画像に追加
let disable = this.state.modelLoading ? 'disable' : ''
return (
<div>
<h2>predict image class name</h2>
<h3>click image</h3>
<img src={CatImg} className={`Cat1 ${disable}`} alt="cat" onClick={this.onClick.bind(this)} />
<img src={AirplainImg} className={`Airplain ${disable}`} alt="airplain" onClick={this.onClick.bind(this)} />
{
this.state.preds.map((pred, index) => {
let name = pred.className
let prob = Number(pred.probability).toFixed(5)
return (
<div key={index}>
<p>class: {name}</p>
<p>probability: {prob}</p>
</div>
)
})
}
</div>
)
}
}
imagenetpage.css
.disable {
pointer-events: none;
}
ページ読み込み部分
App.js
import React, { Component } from 'react'
import ImageNetPage from './components/ImageNetPage'
import "./App.css"
class App extends Component {
render() {
return (
<div>
<ImageNetPage />
</div>
);
}
}
export default App;
index.js
import React from 'react';
import ReactDOM from 'react-dom';
import App from './App';
import registerServiceWorker from './registerServiceWorker';
import './index.css';
ReactDOM.render(<App />, document.getElementById('root'));
registerServiceWorker();
以下のようなページが出来上がる
ちなみにクラス数は1000あるらしい
https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/src/imagenet_classes.ts