一步一步教你在浏览器中用 TensorFlow.js 训练并识别手写数字
今天我们来写一个运行在浏览器中的 TensorFlow.js 人工智能小例子，让计算机能够识别我们手写的数字。
一、第一步，我们需要用到 MNIST 数据集
async load() { // Make a request for the MNIST sprited image. const img = new Image(); const canvas = document.createElement('canvas'); const ctx = canvas.getContext('2d'); const imgRequest = new Promise((resolve, reject) => { img.crossOrigin = ''; img.onload = () => { img.width = img.naturalWidth; img.height = img.naturalHeight; const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); const chunkSize = 5000; canvas.width = img.width; canvas.height = chunkSize; for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { const datasetBytesView = new Float32Array( datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, IMAGE_SIZE * chunkSize); ctx.drawImage( img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize); const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); for (let j = 0; j < imageData.data.length / 4; j++) { // All channels hold an equal value since the image is grayscale, so // just read the red channel. datasetBytesView[j] = imageData.data[j * 4] / 255; } } this.datasetImages = new Float32Array(datasetBytesBuffer); resolve(); }; img.src = MNIST_IMAGES_SPRITE_PATH; }); const labelsRequest = fetch(MNIST_LABELS_PATH); const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]); this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); // Slice the the images and labels into train and test sets. this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); }
二、第二步，我们用这些数据进行训练，生成一个识别模型
function createConvModel() { const model = tf.sequential(); model.add(tf.layers.conv2d({ inputShape: [IMAGE_H, IMAGE_W, 1], kernelSize: 3, filters: 16, activation: 'relu' })); model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); model.add(tf.layers.flatten({})); model.add(tf.layers.dense({ ...
点击查看剩余70%
网友评论0