一步一步教你在浏览器用tensorflow训练手写数字并识别

今天我们来写一个运行在浏览器中的tensorflow.js的人工智能小例子,让计算机可以识别我们手写的数字
一、第一步我们需要用到MNIST的数据
async load() {
// Make a request for the MNIST sprited image.
const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const imgRequest = new Promise((resolve, reject) => {
img.crossOrigin = '';
img.onload = () => {
img.width = img.naturalWidth;
img.height = img.naturalHeight;
const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
const chunkSize = 5000;
canvas.width = img.width;
canvas.height = chunkSize;
for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize);
ctx.drawImage(
img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
chunkSize);
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}
this.datasetImages = new Float32Array(datasetBytesBuffer);
resolve();
};
img.src = MNIST_IMAGES_SPRITE_PATH;
});
const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
await Promise.all([imgRequest,
labelsRequest]);
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
// Slice the the images and labels into train and test sets.
this.trainImages =
this.datasetImages.slice(0,
IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels =
this.datasetLabels.slice(0,
NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}
二、第二步我们来进行数据训练产生一个识别模型
function createConvModel() {
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [IMAGE_H, IMAGE_W, 1],
kernelSize: 3,
filters: 16,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({
poolSize: 2, strides: 2
}));
model.add(tf.layers.conv2d({
kernelSize: 3, filters: 32, activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({
poolSize: 2, strides: 2
}));
model.add(tf.layers.conv2d({
kernelSize: 3, filters: 32, activation: 'relu'
}));
model.add(tf.layers.flatten({}));
model.add(tf.layers.dense({
...点击查看剩余70%
网友评论0