tensorflow 使用Python训练模型,js使用模型
这两天的项目需要用到Tensorflow.js来实现一个AI,尽管说Tensorflow.js本身是有训练模型的功能的,不过考虑到js这个东西加载资源要考虑跨域问题等种种因素。最终还是决定使用python的tensorflow来训练模型,然后利用js端来使用模型进行运算,那么关键问题就是:js如何加载python下训练的模型
首先我们用python写一段tensorflow的模型训练代码line.py
#coding=utf-8# import tensorflow as tf import numpy as np x_data=[[0.0,0.0],[0.0,1.0],[1.0,0.0],[1.0,1.0]] #训练数据 y_data=[[0.0],[1.0],[1.0],[0.0]] #标签 x_test=[[0.0,1.0],[1.0,1.0]] #测试数据 xs=tf.placeholder(tf.float32,[None,2]) ys=tf.placeholder(tf.float32,[None,1]) #定义x和y的占位符作为将要输入神经网络的变量 #构建隐藏层,假设隐藏层有20个神经元 W1=tf.Variable(tf.random_normal([2,10])) B1=tf.Variable(tf.zeros([1,10])+0.1) out1=tf.nn.relu(tf.matmul(xs,W1)+B1) #构建输出层,假设输出层有一个神经元 W2=tf.Variable(tf.random_normal([10,1])) B2=tf.Variable(tf.zeros([1,1])+0.1) prediction=tf.add(tf.matmul(out1,W2),B2,name="model") #计算预测值和真实值之间的误差 loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1])) train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss) init=tf.global_variables_initializer() #初始化所有变量 sess=tf.Session() sess.run(init) for i in range(40): #训练10次 sess.run(train_step,feed_dict={xs:x_data,ys:y_data}) print(sess.run(loss,feed_dict={xs:x_data,ys:y_data})) #打印损失值 re=sess.run(prediction,feed_dict={xs:x_test}) print(re) for x in re: if x[0]>0.5: print(1) else: print(0) # 保存模型为saved_model tf.saved_model.simple_save(sess, "./saved_model",inputs={"x": xs, }, outputs={"model": prediction, })转换模型
tensorflowjs_converter --input_format=tf_saved_model \ --output_node_names="softmax" \ --saved_model_tags=serve ./saved_model \ ./web_model
web_model文件夹中包含
group1-shard1of1是转换后的模型文件
tensorflowjs_model.pb 为 tensorflow.js能识别的模型
weights_manifest.json 为 tensorflow.js能识别的模型参数文件
好的,接下来就是用js调用模型了
<!doctype html> <html lang="en"> <head> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script> </head> <body> <img style="display: none" id="cat" src="high-detail.jpg" width="224" height="224"> <script> const MODEL_URL = './tensorflowjs_model.pb' const WEIGHTS_URL = './weights_manifest.json' async function fun(){ const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL) const cs = tf.tensor([[1.0,1.0],[0.0,0.0]]) cs.print() model.predict(cs).print() } fun() </script> </body> </html>注意点:跨域问题,将模型文件与html放在同一个域名目录下,避免出现跨域问题
网友评论0