A step-by-step guide to training a handwritten-digit recognizer with TensorFlow in the browser

Today we'll write a small AI example with TensorFlow.js that runs in the browser and lets the computer recognize digits we write by hand.

1. Step one: load the MNIST data
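The `load()` method that follows relies on a particular file layout. There is one sprite image in which each row is one flattened digit (784 pixels, assuming 28×28 digits), and a label file that stores a one-hot row of `NUM_CLASSES` bytes per image. That layout is why the labels are sliced at `NUM_CLASSES * NUM_TRAIN_ELEMENTS`. The plain-JavaScript sketch below shows how example *k* is located in those flat arrays and how large each side of the train/test split ends up. It uses hypothetical helper names (`exampleRange`, `labelToDigit`, `splitLengths`). The 65,000 / 55,000 split used in the usage note is an assumption taken from the tfjs-examples MNIST demo, since the article does not state its own values.

```javascript
// Hypothetical helpers illustrating how MNIST examples sit in flat typed arrays.

// [start, end) range of example k in an array holding `width` values per example
// (784 for pixels of a 28x28 image, 10 for a one-hot label row).
function exampleRange(k, width) {
  return [k * width, (k + 1) * width];
}

// Decodes the one-hot label row of example k into the digit it encodes.
// Returns -1 if the row contains no 1 (which should not happen for valid data).
function labelToDigit(labels, k, numClasses = 10) {
  const [start, end] = exampleRange(k, numClasses);
  return labels.subarray(start, end).indexOf(1);
}

// Number of array elements on each side of a split made after the first
// numTrain examples, i.e. array.slice(0, numTrain * width) vs. the rest.
function splitLengths(numExamples, numTrain, width) {
  return {
    train: numTrain * width,
    test: (numExamples - numTrain) * width,
  };
}
```

With the assumed 65,000 images and 55,000 training examples, `splitLengths(65000, 55000, 784)` gives 43,120,000 training pixels and 7,840,000 test pixels, and `splitLengths(65000, 55000, 10)` gives 550,000 and 100,000 label bytes.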

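// Constants used by load(). The article does not show them; these are assumed
// values matching the tfjs-examples MNIST demo this loader is based on.
const IMAGE_H = 28;
const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;
const NUM_TRAIN_ELEMENTS = 55000;
const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

class MnistData {
    async load() {
        // Make a request for the MNIST sprite image: NUM_DATASET_ELEMENTS rows,
        // each row one digit flattened to IMAGE_SIZE pixels.
        const img = new Image();
        const canvas = document.createElement('canvas');
        const ctx = canvas.getContext('2d');
        const imgRequest = new Promise((resolve, reject) => {
            img.crossOrigin = '';
            img.onerror = reject;
            img.onload = () => {
                img.width = img.naturalWidth;
                img.height = img.naturalHeight;

                // One float32 (4 bytes) per pixel for the whole dataset.
                const datasetBytesBuffer =
                    new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

                // Draw the sprite chunkSize rows (images) at a time to keep the canvas small.
                const chunkSize = 5000;
                canvas.width = img.width;
                canvas.height = chunkSize;

                for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
                    // View onto the part of the shared buffer that holds chunk i.
                    const datasetBytesView = new Float32Array(
                        datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                        IMAGE_SIZE * chunkSize);
                    ctx.drawImage(
                        img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                        chunkSize);

                    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

                    for (let j = 0; j < imageData.data.length / 4; j++) {
                        // All channels hold an equal value since the image is grayscale, so
                        // just read the red channel and scale it to [0, 1].
                        datasetBytesView[j] = imageData.data[j * 4] / 255;
                    }
                }
                this.datasetImages = new Float32Array(datasetBytesBuffer);

                resolve();
            };
            img.src = MNIST_IMAGES_SPRITE_PATH;
        });

        // The label file stores one one-hot row of NUM_CLASSES bytes per image.
        const labelsRequest = fetch(MNIST_LABELS_PATH);
        const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

        this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

        // Slice the images and labels into train and test sets.
        this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
        this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    }
}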
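The inner loop of `load()` does two bookkeeping jobs. It points a `Float32Array` view at the right byte offset of the shared buffer for chunk `i`, and it collapses each RGBA pixel returned by `getImageData()` into one value in [0, 1] by reading only the red channel. Here is a DOM-free sketch of both steps that can run in Node with a hand-made pixel array standing in for `getImageData()`. The helper names `rgbaToGrayscale`, `chunkByteOffset` and `writeChunk` are hypothetical.

```javascript
// Converts RGBA bytes (4 per pixel, as in ImageData.data) into grayscale
// values in [0, 1]. For grayscale images R === G === B, so the red channel
// alone carries the intensity, the same shortcut load() takes.
function rgbaToGrayscale(rgba) {
  const out = new Float32Array(rgba.length / 4);
  for (let j = 0; j < out.length; j++) {
    out[j] = rgba[j * 4] / 255;
  }
  return out;
}

// Byte offset of chunk i in the shared ArrayBuffer: each chunk holds
// chunkSize images of imageSize float32 values (4 bytes each).
function chunkByteOffset(i, imageSize, chunkSize) {
  return i * imageSize * chunkSize * Float32Array.BYTES_PER_ELEMENT;
}

// Writes one decoded chunk into its slice of the shared buffer through a
// Float32Array view, mirroring datasetBytesView in load().
function writeChunk(buffer, i, imageSize, chunkSize, rgba) {
  const view = new Float32Array(
    buffer, chunkByteOffset(i, imageSize, chunkSize), imageSize * chunkSize);
  view.set(rgbaToGrayscale(rgba));
  return view;
}
```

Because every chunk writes through its own view, the loop never copies data between chunks. Under the assumed 784-pixel images and `chunkSize = 5000`, chunk 1 starts at byte 15,680,000, and with 65,000 images the outer loop runs 13 times.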

2. Step two: train on the data to produce a recognition model

function createConvModel() {
    const model = tf.sequential();
    // Shapes below assume IMAGE_H = IMAGE_W = 28 and the default 'valid' padding.
    // 28x28x1 -> 26x26x16
    model.add(tf.layers.conv2d({
        inputShape: [IMAGE_H, IMAGE_W, 1],
        kernelSize: 3,
        filters: 16,
        activation: 'relu'
    }));
    // -> 13x13x16
    model.add(tf.layers.maxPooling2d({
        poolSize: 2, strides: 2
    }));
    // -> 11x11x32
    model.add(tf.layers.conv2d({
        kernelSize: 3, filters: 32, activation: 'relu'
    }));
    // -> 5x5x32
    model.add(tf.layers.maxPooling2d({
        poolSize: 2, strides: 2
    }));
    // -> 3x3x32
    model.add(tf.layers.conv2d({
        kernelSize: 3, filters: 32, activation: 'relu'
    }));
    // -> vector of 3 * 3 * 32 = 288 values
    model.add(tf.layers.flatten({}));
    model.add(tf.layers.dense({
        ...
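Because `tf.layers.conv2d` and `tf.layers.maxPooling2d` default to `'valid'` padding (and `conv2d` to stride 1), every layer shrinks the feature map by a predictable amount: `floor((input - window) / stride) + 1` along each axis. The sketch below traces the shape through the five layers shown in `createConvModel()` and computes the length of the vector that `flatten` hands to the dense layer. It assumes 28×28 inputs and uses hypothetical helper names (`validOutputSize`, `traceConvShapes`). The dense layer itself is cut off in the article and is not modeled.

```javascript
// Output length along one axis of an unpadded ('valid') sliding window.
function validOutputSize(input, window, stride = 1) {
  return Math.floor((input - window) / stride) + 1;
}

// Traces [height, width, channels] through the conv/pool stack of
// createConvModel(), assuming 'valid' padding everywhere.
function traceConvShapes(h, w) {
  const layers = [
    { name: 'conv2d 3x3, 16 filters', window: 3, stride: 1, channels: 16 },
    { name: 'maxPooling2d 2x2, stride 2', window: 2, stride: 2 },
    { name: 'conv2d 3x3, 32 filters', window: 3, stride: 1, channels: 32 },
    { name: 'maxPooling2d 2x2, stride 2', window: 2, stride: 2 },
    { name: 'conv2d 3x3, 32 filters', window: 3, stride: 1, channels: 32 },
  ];
  let shape = [h, w, 1];
  const trace = [];
  for (const layer of layers) {
    shape = [
      validOutputSize(shape[0], layer.window, layer.stride),
      validOutputSize(shape[1], layer.window, layer.stride),
      layer.channels ?? shape[2], // pooling keeps the channel count
    ];
    trace.push({ layer: layer.name, shape });
  }
  // flatten() turns the final feature map into a single vector.
  const flattened = shape[0] * shape[1] * shape[2];
  return { trace, flattened };
}
```

For 28×28 inputs, `traceConvShapes(28, 28).flattened` is 288, so the first dense layer receives 288 inputs per image.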
