tensorflow 使用Python训练模型,js使用模型

tensorflow 使用Python训练模型,js使用模型

tensorflow 使用Python训练模型,js使用模型

这两天的项目需要用到Tensorflow.js来实现一个AI,尽管说Tensorflow.js本身是有训练模型的功能的,不过考虑到js这个东西加载资源要考虑跨域问题等种种因素。最终还是决定使用python的tensorflow来训练模型,然后利用js端来使用模型进行运算,那么关键问题就是:js如何加载python下训练的模型

首先我们用python写一段tensorflow的模型训练代码line.py

#coding=utf-8#
import tensorflow as tf
import numpy as np
x_data=[[0.0,0.0],[0.0,1.0],[1.0,0.0],[1.0,1.0]]	#训练数据
y_data=[[0.0],[1.0],[1.0],[0.0]]	#标签
x_test=[[0.0,1.0],[1.0,1.0]]	#测试数据
xs=tf.placeholder(tf.float32,[None,2])
ys=tf.placeholder(tf.float32,[None,1])	#定义x和y的占位符作为将要输入神经网络的变量
 
#构建隐藏层,假设隐藏层有20个神经元
W1=tf.Variable(tf.random_normal([2,10]))
B1=tf.Variable(tf.zeros([1,10])+0.1)
out1=tf.nn.relu(tf.matmul(xs,W1)+B1)
#构建输出层,假设输出层有一个神经元
W2=tf.Variable(tf.random_normal([10,1]))
B2=tf.Variable(tf.zeros([1,1])+0.1)
prediction=tf.add(tf.matmul(out1,W2),B2,name="model")
#计算预测值和真实值之间的误差
loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
 
init=tf.global_variables_initializer()	#初始化所有变量
sess=tf.Session()
sess.run(init)
 
for i in range(40):	#训练10次
	sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
	print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))	#打印损失值
re=sess.run(prediction,feed_dict={xs:x_test})
print(re)
for x in re:   
	if x[0]>0.5:
		print(1)
	else:
		print(0)
# 保存模型为saved_model
tf.saved_model.simple_save(sess, "./saved_model",inputs={"x": xs, }, outputs={"model": prediction, })

转换模型

tensorflowjs_converter --input_format=tf_saved_model \
  --output_node_names="softmax" \
  --saved_model_tags=serve ./saved_model \
  ./web_model

tensorflow 使用Python训练模型,js使用模型

web_model文件夹中包含

tensorflow 使用Python训练模型,js使用模型

group1-shard1of1是转换后的模型文件

tensorflowjs_model.pb 为 tensorflow.js能识别的模型

weights_manifest.json 为 tensorflow.js能识别的模型参数文件

好的,接下来就是用js调用模型了

<!doctype html>
<html lang="en">
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
</head>
<body>
  <img style="display: none" id="cat" src="high-detail.jpg" width="224" height="224">
 
  <script>
const MODEL_URL = './tensorflowjs_model.pb'
const WEIGHTS_URL = './weights_manifest.json'
async function fun(){
    const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL)
      const cs = tf.tensor([[1.0,1.0],[0.0,0.0]])
	cs.print()
	model.predict(cs).print()
}
fun()
  </script>
</body>
</html>

注意点:跨域问题,将模型文件与html放在同一个域名目录下,避免出现跨域问题

tensorflow 使用Python训练模型,js使用模型

{{collectdata}}

网友评论0