Step by Step: Training and Recognizing Handwritten Digits in the Browser with TensorFlow.js

Today we'll build a small TensorFlow.js AI example that runs in the browser and lets the computer recognize digits we write by hand.
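
The code below assumes TensorFlow.js is already loaded on the page. The original excerpt doesn't show the host markup; here is a minimal sketch that pulls the library from the public jsDelivr CDN (index.js is a placeholder name for a script holding the code from the steps below):

 <!DOCTYPE html>
 <html>
 <head>
     <!-- TensorFlow.js from the public CDN; pin a specific version for production. -->
     <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
 </head>
 <body>
     <div id="status"></div>
     <!-- The data-loading, training, and prediction code from the steps below. -->
     <script src="index.js"></script>
 </body>
 </html>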

Step 1: Load the MNIST data

 async load() {
                // Make a request for the MNIST sprited image.
                const img = new Image();
                const canvas = document.createElement('canvas');
                const ctx = canvas.getContext('2d');
                const imgRequest = new Promise((resolve, reject) => {
                    img.crossOrigin = '';
                    img.onload = () => {
                        img.width = img.naturalWidth;
                        img.height = img.naturalHeight;

                        const datasetBytesBuffer =
                            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

                        const chunkSize = 5000;
                        canvas.width = img.width;
                        canvas.height = chunkSize;

                        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
                            const datasetBytesView = new Float32Array(
                                datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                                IMAGE_SIZE * chunkSize);
                            ctx.drawImage(
                                img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                                chunkSize);

                            const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

                            for (let j = 0; j < imageData.data.length / 4; j++) {
                                // All channels hold an equal value since the image is grayscale, so
                                // just read the red channel.
                                datasetBytesView[j] = imageData.data[j * 4] / 255;
                            }
                        }
                        this.datasetImages = new Float32Array(datasetBytesBuffer);

                        resolve();
                    };
                    img.src = MNIST_IMAGES_SPRITE_PATH;
                });

                const labelsRequest = fetch(MNIST_LABELS_PATH);
                const [imgResponse, labelsResponse] =
                    await Promise.all([imgRequest, labelsRequest]);

                this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

                // Slice the images and labels into train and test sets.
                this.trainImages =
                    this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
                this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
                this.trainLabels =
                    this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
                this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
            }
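
The loader above references a handful of constants and data paths that the snippet doesn't define. The values below mirror the official tfjs-examples MNIST demo, which this loader closely follows; the two storage.googleapis.com URLs are that demo's hosted sprite and labels, so swap them in if you don't host the data yourself:

 const IMAGE_H = 28;
 const IMAGE_W = 28;
 const IMAGE_SIZE = IMAGE_H * IMAGE_W; // 784 pixels per digit
 const NUM_CLASSES = 10;               // digits 0-9
 const NUM_DATASET_ELEMENTS = 65000;   // total examples in the sprite image
 const NUM_TRAIN_ELEMENTS = 55000;     // the remaining 10000 are for testing

 const MNIST_IMAGES_SPRITE_PATH =
     'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
 const MNIST_LABELS_PATH =
     'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';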

Step 2: Define the convolutional model we'll train

 function createConvModel() {
            const model = tf.sequential();

            // First conv layer: 16 3x3 filters over the 28x28x1 grayscale input.
            model.add(tf.layers.conv2d({
                inputShape: [IMAGE_H, IMAGE_W, 1],
                kernelSize: 3,
                filters: 16,
                activation: 'relu'
            }));
            // Max pooling halves the spatial resolution.
            model.add(tf.layers.maxPooling2d({
                poolSize: 2, strides: 2
            }));
            model.add(tf.layers.conv2d({
                kernelSize: 3, filters: 32, activation: 'relu'
            }));
            model.add(tf.layers.maxPooling2d({
                poolSize: 2, strides: 2
            }));
            model.add(tf.layers.conv2d({
                kernelSize: 3, filters: 32, activation: 'relu'
            }));
            // Flatten the feature maps into a vector for the dense layers.
            model.add(tf.layers.flatten({}));
            model.add(tf.layers.dense({
                units: 64, activation: 'relu'
            }));
            // 10 output units with softmax: one probability per digit class.
            model.add(tf.layers.dense({
                units: 10, activation: 'softmax'
            }));
            return model;
        }
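
To sanity-check the architecture before training, you can build the model and print it; tf.LayersModel's summary() logs each layer's output shape and parameter count to the browser console:

 const model = createConvModel();
 model.summary(); // layer-by-layer output shapes and parameter counts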

Step 3: Train the model on the MNIST data

       async function train(model, onIteration) {
            logStatus('Training model...');

            // Now that we've defined our model, we will define our optimizer. The
            // optimizer will be used to optimize our model's weight values during
            // training so that we can decrease our training loss and increase our
            // classification accuracy.

            // The learning rate defines the magnitude by which we update our weights each
            // training step. The higher the value, the faster our loss values converge,
            // but also the more likely we are to overshoot optimal parameters
            // when making an update. A learning rate that is too low will take too long
            // to find optimal (or good enough) weight parameters while a learning rate
            // that is too high may overshoot optimal parameters. Learning rate is one of
            // the most important hyperparameters to set correctly. Finding the right
            // value takes practice and is often best found empirically by trying many
            // values.
            const LEARNING_RATE = 0.01;

            // We are using rmsprop as our optimizer, constructed with our chosen
            // learning rate. An optimizer is an iterative method for minimizing a
            // loss function: it tries to find the minimum of our loss function
            // with respect to the model's weight parameters.
            const optimizer = tf.train.rmsprop(LEARNING_RATE);

            // We compile our model by specifying an optimizer, a loss function, and a
            // list of metrics that we will use for model evaluation. Here we're using a
            // categorical crossentropy loss, the standard choice for a multi-class
            // classification problem like MNIST digits.
            // The categorical crossentropy loss is differentiable and hence makes
            // model training possible. But it is not amenable to easy interpretation
            // by a human. This is why we include a "metric", namely accuracy, which is
            // simply a measure of how many of the examples are classified correctly.
            // This metric is not differentiable and hence cannot be used as the loss
            // function of the model.
            model.compile({
                optimizer,
                loss: 'categoricalCrossentropy',
                metrics: ['accuracy'],
            });

            // Batch size is another important hyperparameter. It defines the number of
            // examples we group together, or batch, between updates to the model's
            // weights during training. A value that is too low will update weights using
            // too few examples and will not generalize well. Larger batch sizes require
            // more memory resources and aren't guaranteed to perform better.
            const batchSize = 320;

            // Leave out the last 15% of the training data for validation, to monitor
            // overfitting during training.
            const validationSplit = 0.15;

            // Get number of training epochs from the UI.
            const trainEpochs = getTrainEpochs();

            // We'll keep a buffer of loss and accuracy values over time.
            let trainBatchCount = 0;

            const trainData = data.getTrainData();
            const testData = data.getTestData();

            const totalNumBatches =
                Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) *
                trainEpochs;

            // During the long-running fit() call for model training, we include
            // callbacks, so that we can plot the loss and accuracy values in the page
            // as the training progresses.
            let valAcc;
            await model.fit(trainData.xs, trainData.labels, {
                batchSize,
                validationSplit,
                epochs: trainEpochs,
                callbacks: {
                    onBatchEnd: async (batch, logs) => {
                        trainBatchCount++;
                        logStatus(
                            `Training... (` +
                            `${(trainBatchCount / totalNumBatches * 100).toFixed(1)}%` +
                            ` complete). To stop training, refresh or close page.`);
                        plotLoss(trainBatchCount, logs.loss, 'train');
                        plotAccuracy(trainBatchCount, logs.acc, 'train');
                        if (onIteration && batch % 10 === 0) {
                            onIteration('onBatchEnd', batch, logs);
                        }
                        await tf.nextFrame();
                    },
                    onEpochEnd: async (epoch, logs) => {
                        valAcc = logs.val_acc;
                        plotLoss(trainBatchCount, logs.val_loss, 'validation');
                        plotAccuracy(trainBatchCount, logs.val_acc, 'validation');
                        if (onIteration) {
                            onIteration('onEpochEnd', epoch, logs);
                        }
                        await tf.nextFrame();
                    }
                }
            });
            
            const testResult = model.evaluate(testData.xs,
                testData.labels);
            const testAccPercent = testResult[1].dataSync()[0] * 100;
            const finalValAccPercent = valAcc * 100;
            logStatus(
                `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` +
                `Final test accuracy: ${testAccPercent.toFixed(1)}%`);
        }
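
train() assumes a data object exposing getTrainData() and getTestData(). In the official tfjs-examples demo these sit on the same class as load() from step 1: they reshape the flat Float32Array pixels into a 4-D [numExamples, 28, 28, 1] tensor and the labels into a 2-D one-hot tensor. A sketch along those lines:

 getTrainData() {
     const xs = tf.tensor4d(
         this.trainImages,
         [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
     const labels = tf.tensor2d(
         this.trainLabels, [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]);
     return {xs, labels};
 }

 getTestData(numExamples) {
     let xs = tf.tensor4d(
         this.testImages,
         [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
     let labels = tf.tensor2d(
         this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]);
     if (numExamples != null) {
         // Only the first numExamples digits, e.g. the 100 shown in step 4.
         xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
         labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
     }
     return {xs, labels};
 }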

Step 4: The model is trained, so let's verify its accuracy

            const testExamples = 100;
            const examples = data.getTestData(testExamples);

            // Tensors created inside a tf.tidy() callback are freed from GPU
            // memory after execution, without having to call dispose().
            // The tf.tidy callback runs synchronously.
            tf.tidy(() => {
                const output = model.predict(examples.xs);

                // tf.argMax() returns the indices of the maximum values in the tensor along
                // a specific axis. Categorical classification tasks like this one often
                // represent classes as one-hot vectors. One-hot vectors are 1D vectors with
                // one element for each output class. All values in the vector are 0
                // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The
                // output from model.predict() will be a probability distribution, so we use
                // argMax to get the index of the vector element that has the highest
                // probability. This is our prediction.
                // (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
                // dataSync() synchronously downloads the tf.tensor values from the GPU so
                // that we can use them in our normal CPU JavaScript code
                // (for a non-blocking version of this function, use data()).
                const axis = 1;
                const labels = Array.from(examples.labels.argMax(axis).dataSync());
                const predictions = Array.from(output.argMax(axis).dataSync());

                showTestResults(examples, predictions, labels);
            });
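
To close the loop and recognize a digit you draw yourself, the model just needs the same [1, 28, 28, 1] input it was trained on. The drawing UI isn't part of this excerpt; here is a minimal sketch, assuming a 28x28 <canvas id="draw"> the user has sketched on (tf.browser.fromPixels is the current API; very old builds named it tf.fromPixels):

 function predictDrawnDigit(model) {
     const canvas = document.getElementById('draw'); // assumed 28x28 drawing canvas
     return tf.tidy(() => {
         // One grayscale channel, scaled to [0, 1], plus a batch dimension.
         let input = tf.browser.fromPixels(canvas, 1)
             .toFloat()
             .div(255)
             .expandDims(0);
         // MNIST digits are light-on-dark; invert if you draw dark-on-light:
         // input = tf.sub(1, input);
         return model.predict(input).argMax(1).dataSync()[0]; // digit 0-9
     });
 }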

And that's it: a simple handwritten-digit recognizer written with TensorFlow.js.

Full source code: http://studio.bfw.wiki/Studio/Open/lang/html/id/15930025023507870045.html

