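Step by step: train a handwritten-digit recognizer in the browser with TensorFlow.js
Today we'll write a small AI example with TensorFlow.js that runs entirely in the browser and lets the computer recognize digits we write by hand.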
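1. Step one: load the MNIST data
MNIST is a dataset of 28x28 grayscale images of handwritten digits, each labeled 0 to 9. The load() method below fetches all the images packed into a single sprite image, draws the sprite onto a canvas in chunks to read the pixels, scales every pixel to the range [0, 1], downloads the labels, and splits both into training and test sets.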
// load() is a method of the data-loading class in the full example.
// It assumes these constants are defined elsewhere in that code:
// IMAGE_SIZE (pixels per image), NUM_CLASSES, NUM_DATASET_ELEMENTS,
// NUM_TRAIN_ELEMENTS, MNIST_IMAGES_SPRITE_PATH and MNIST_LABELS_PATH.
async load() {
  // Make a request for the MNIST sprited image.
  const img = new Image();
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');
  const imgRequest = new Promise((resolve, reject) => {
    img.crossOrigin = '';
    // Reject instead of hanging forever if the sprite fails to load.
    img.onerror = reject;
    img.onload = () => {
      img.width = img.naturalWidth;
      img.height = img.naturalHeight;

      const datasetBytesBuffer =
          new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

      // Each row of the sprite holds one flattened image; draw and read
      // chunkSize rows (images) at a time.
      const chunkSize = 5000;
      canvas.width = img.width;
      canvas.height = chunkSize;

      for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
        const datasetBytesView = new Float32Array(
            datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize);
        ctx.drawImage(
            img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
            chunkSize);

        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

        for (let j = 0; j < imageData.data.length / 4; j++) {
          // All channels hold an equal value since the image is grayscale, so
          // just read the red channel and scale it to [0, 1].
          datasetBytesView[j] = imageData.data[j * 4] / 255;
        }
      }
      this.datasetImages = new Float32Array(datasetBytesBuffer);

      resolve();
    };
    img.src = MNIST_IMAGES_SPRITE_PATH;
  });

  const labelsRequest = fetch(MNIST_LABELS_PATH);
  const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

  // Labels are one-hot encoded: NUM_CLASSES bytes per example.
  this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

  // Slice the images and labels into train and test sets.
  this.trainImages =
      this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  this.trainLabels =
      this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  this.testLabels =
      this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}
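To make the data handling in load() easier to follow, here is a minimal sketch in plain JavaScript (no browser or TensorFlow.js needed) of its two core steps: turning canvas RGBA bytes into grayscale values in [0, 1], and slicing flat image and one-hot label arrays into training and test sets. The function names (rgbaToGrayscale, splitDataset, labelAt) and the toy sizes in the usage lines are hypothetical, chosen only for illustration.

```javascript
// Hypothetical helpers (not from the original tutorial) mirroring load().

// Canvas getImageData() returns RGBA bytes, 4 per pixel. MNIST is
// grayscale, so R === G === B and reading the red channel is enough.
function rgbaToGrayscale(rgba) {
  const out = new Float32Array(rgba.length / 4);
  for (let j = 0; j < out.length; j++) {
    out[j] = rgba[j * 4] / 255; // scale 0..255 to 0..1
  }
  return out;
}

// Split at an example boundary: each image takes imageSize floats and
// each one-hot label takes numClasses bytes.
function splitDataset(images, labels, imageSize, numClasses, numTrain) {
  return {
    trainImages: images.slice(0, imageSize * numTrain),
    testImages: images.slice(imageSize * numTrain),
    trainLabels: labels.slice(0, numClasses * numTrain),
    testLabels: labels.slice(numClasses * numTrain),
  };
}

// Read the digit of example i back out of a one-hot label array.
function labelAt(labels, numClasses, i) {
  return labels.subarray(i * numClasses, (i + 1) * numClasses).indexOf(1);
}

// Toy dataset: 3 "images" of 4 pixels each, with labels 7, 2 and 9.
const NUM_CLASSES = 10;
const images = new Float32Array(3 * 4);
const labels = new Uint8Array(3 * NUM_CLASSES);
[7, 2, 9].forEach((digit, i) => { labels[i * NUM_CLASSES + digit] = 1; });

const split = splitDataset(images, labels, 4, NUM_CLASSES, 2);
console.log(split.trainImages.length, split.testImages.length); // 8 4
console.log(labelAt(split.testLabels, NUM_CLASSES, 0)); // 9
```

Note that TypedArray slice() copies the data, so after load() runs the page holds both the full dataset and the split copies; subarray() would create views into the same buffer instead.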
2. Step two: define the convolutional neural network that will become our recognition model
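// Assumes TensorFlow.js is loaded on the page (global `tf`) and that
// IMAGE_H and IMAGE_W (both 28 for MNIST) are defined elsewhere.
function createConvModel() {
  const model = tf.sequential();

  // Conv layer 1: 16 filters of size 3x3 over the single-channel input image.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));
  // Max pooling halves the spatial resolution.
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

  model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
  model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));

  // Flatten the 3D feature maps into a vector for the dense layers.
  model.add(tf.layers.flatten({}));
  model.add(tf.layers.dense({units: 64, activation: 'relu'}));

  // 10 output units, one per digit; softmax turns them into probabilities.
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  return model;
}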
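A quick way to see what createConvModel() builds is to trace the output shape and the number of trainable parameters through each layer. The plain-JavaScript sketch below does this by hand. It assumes IMAGE_H = IMAGE_W = 28 (the MNIST image size) and the default 'valid' padding of tf.layers.conv2d and tf.layers.maxPooling2d; the helper names are hypothetical.

```javascript
// Hypothetical helper (not from the tutorial): trace output shapes and
// parameter counts through the layers of createConvModel().
// Assumes 'valid' padding, the default for conv2d and maxPooling2d.

// Output length along one spatial axis for a window of size k and stride s.
function validOut(size, k, s = 1) {
  return Math.floor((size - k) / s) + 1;
}

function traceConvModel(imageH, imageW) {
  let h = imageH, w = imageW, c = 1; // one grayscale input channel
  let units = 0;                     // vector length after flatten/dense
  const layers = [];

  const conv2d = (filters, k) => {
    const params = k * k * c * filters + filters; // kernel weights + biases
    h = validOut(h, k); w = validOut(w, k); c = filters;
    layers.push({ layer: 'conv2d', shape: [h, w, c], params });
  };
  const maxPooling2d = (pool, stride) => {
    h = validOut(h, pool, stride); w = validOut(w, pool, stride);
    layers.push({ layer: 'maxPooling2d', shape: [h, w, c], params: 0 });
  };
  const flatten = () => {
    units = h * w * c;
    layers.push({ layer: 'flatten', shape: [units], params: 0 });
  };
  const dense = (n) => {
    const params = units * n + n; // weight matrix + biases
    units = n;
    layers.push({ layer: 'dense', shape: [n], params });
  };

  // Same layer sequence and hyperparameters as createConvModel().
  conv2d(16, 3);
  maxPooling2d(2, 2);
  conv2d(32, 3);
  maxPooling2d(2, 2);
  conv2d(32, 3);
  flatten();
  dense(64);
  dense(10);
  return layers;
}

const layers = traceConvModel(28, 28); // assumes IMAGE_H = IMAGE_W = 28
for (const { layer, shape, params } of layers) {
  console.log(layer.padEnd(13), JSON.stringify(shape).padEnd(12), params);
}
```

With these settings the feature maps shrink 28 → 26 → 13 → 11 → 5 → 3, so flatten produces 3 × 3 × 32 = 288 values and the whole network has 33,194 trainable parameters. Calling model.summary() on the real model is an easy way to cross-check these numbers.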
3. Step three: train the model on the MNIST data
// Assumes helpers defined elsewhere in the full page code: `data` (the
// loaded MNIST dataset), getTrainEpochs(), logStatus(), plotLoss() and
// plotAccuracy().
async function train(model, onIteration) {
  logStatus('Training model...');

  // Now that we've defined our model, we will define our optimizer. The
  // optimizer will be used to optimize our model's weight values during
  // training so that we can decrease our training loss and increase our
  // classification accuracy.

  // The learning rate defines the magnitude by which we update our weights each
  // training step. The higher the value, the faster our loss values converge,
  // but also the more likely we are to overshoot optimal parameters
  // when making an update. A learning rate that is too low will take too long
  // to find optimal (or good enough) weight parameters while a learning rate
  // that is too high may overshoot optimal parameters. Learning rate is one of
  // the most important hyperparameters to set correctly. Finding the right
  // value takes practice and is often best found empirically by trying many
  // values.
  const LEARNING_RATE = 0.01;

  // We are using rmsprop as our optimizer.
  // An optimizer is an iterative method for minimizing a loss function.
  // It tries to find the minimum of our loss function with respect to the
  // model's weight parameters.
  // Note: the string 'rmsprop' uses the optimizer's default learning rate,
  // so LEARNING_RATE above is not actually applied. To use it, pass
  // tf.train.rmsprop(LEARNING_RATE) instead of the string.
  const optimizer = 'rmsprop';

  // We compile our model by specifying an optimizer, a loss function, and a
  // list of metrics that we will use for model evaluation. Here we're using a
  // categorical crossentropy loss, the standard choice for a multi-class
  // classification problem like MNIST digits.
  // The categorical crossentropy loss is differentiable and hence makes
  // model training possible. But it is not amenable to easy interpretation
  // by a human. This is why we include a "metric", namely accuracy, which is
  // simply a measure of how many of the examples are classified correctly.
  // This metric is not differentiable and hence cannot be used as the loss
  // function of the model.
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  // Batch size is another important hyperparameter. It defines the number of
  // examples we group together, or batch, between updates to the model's
  // weights during training. A value that is too low will update weights using
  // too few examples and will not generalize well. Larger batch sizes require
  // more memory resources and aren't guaranteed to perform better.
  const batchSize = 320;

  // Leave out the last 15% of the training data for validation, to monitor
  // overfitting during training.
  const validationSplit = 0.15;

  // Get number of training epochs from the UI.
  const trainEpochs = getTrainEpochs();

  // Count batches so we can plot loss and accuracy values over time.
  let trainBatchCount = 0;

  const trainData = data.getTrainData();
  const testData = data.getTestData();

  const totalNumBatches =
      Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) *
      trainEpochs;

  // During the long-running fit() call for model training, we include
  // callbacks, so that we can plot the loss and accuracy values in the page
  // as the training progresses.
  let valAcc;
  await model.fit(trainData.xs, trainData.labels, {
    batchSize,
    validationSplit,
    epochs: trainEpochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        trainBatchCount++;
        logStatus(
            `Training... (` +
            `${(trainBatchCount / totalNumBatches * 100).toFixed(1)}%` +
            ` complete). To stop training, refresh or close page.`);
        plotLoss(trainBatchCount, logs.loss, 'train');
        plotAccuracy(trainBatchCount, logs.acc, 'train');
        if (onIteration && batch % 10 === 0) {
          onIteration('onBatchEnd', batch, logs);
        }
        // Yield to the browser so the page can repaint between batches.
        await tf.nextFrame();
      },
      onEpochEnd: async (epoch, logs) => {
        valAcc = logs.val_acc;
        plotLoss(trainBatchCount, logs.val_loss, 'validation');
        plotAccuracy(trainBatchCount, logs.val_acc, 'validation');
        if (onIteration) {
          onIteration('onEpochEnd', epoch, logs);
        }
        await tf.nextFrame();
      }
    }
  });

  // With an accuracy metric, evaluate() returns [loss, accuracy] as scalars.
  const testResult = model.evaluate(testData.xs, testData.labels);
  const testAccPercent = testResult[1].dataSync()[0] * 100;
  const finalValAccPercent = valAcc * 100;
  logStatus(
      `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` +
      `Final test accuracy: ${testAccPercent.toFixed(1)}%`);
}
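The loss that fit() minimizes and the totalNumBatches arithmetic behind the progress message can both be reproduced in plain JavaScript, which makes them easy to check by hand. The sketch below illustrates the formulas only; it is not TensorFlow.js's implementation, and the function names and the 55000-image dataset size in the example are hypothetical.

```javascript
// Hypothetical helpers illustrating train(); not TensorFlow.js code.

// Mean categorical crossentropy over a batch: -sum_k y_k * log(p_k).
// yTrue rows are one-hot labels, yPred rows are softmax probabilities.
function categoricalCrossentropy(yTrue, yPred, eps = 1e-7) {
  let total = 0;
  for (let i = 0; i < yTrue.length; i++) {
    for (let k = 0; k < yTrue[i].length; k++) {
      // Clip the probability so log(0) can never produce -Infinity.
      const p = Math.min(Math.max(yPred[i][k], eps), 1 - eps);
      total -= yTrue[i][k] * Math.log(p);
    }
  }
  return total / yTrue.length;
}

// Same formula as totalNumBatches in train(): hold out the validation
// split, cut each epoch into ceil(n / batchSize) batches, times epochs.
function totalNumBatches(numExamples, validationSplit, batchSize, epochs) {
  return Math.ceil(numExamples * (1 - validationSplit) / batchSize) * epochs;
}

// Both predictions pick digit 1, so their accuracy is identical (100%),
// but the loss still rewards the more confident one. That smooth signal
// is what lets gradient-based training improve the weights.
const label = [[0, 1, 0]];
const confident = categoricalCrossentropy(label, [[0.05, 0.9, 0.05]]);
const unsure = categoricalCrossentropy(label, [[0.3, 0.4, 0.3]]);
console.log(confident.toFixed(3), unsure.toFixed(3)); // 0.105 0.916

// Hypothetical: 55000 training images, the tutorial's batchSize of 320
// and validationSplit of 0.15, and 3 epochs chosen in the UI.
console.log(totalNumBatches(55000, 0.15, 320, 3)); // 441
```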
4. Step four: with the model trained, check how accurate it is
// Assumes `model` is the trained model and that `data` and
// showTestResults() are defined elsewhere in the full page code.
const testExamples = 100;
const examples = data.getTestData(testExamples);

// Code wrapped in a tf.tidy() function callback will have its tensors freed
// from GPU memory after execution without having to call dispose().
// The tf.tidy callback runs synchronously.
tf.tidy(() => {
  const output = model.predict(examples.xs);

  // tf.argMax() returns the indices of the maximum values in the tensor along
  // a specific axis. Categorical classification tasks like this one often
  // represent classes as one-hot vectors. One-hot vectors are 1D vectors with
  // one element for each output class. All values in the vector are 0
  // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The
  // output from model.predict() will be a probability distribution, so we use
  // argMax to get the index of the vector element that has the highest
  // probability. This is our prediction.
  // (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
  // dataSync() synchronously downloads the tf.tensor values from the GPU so
  // that we can use them in our normal CPU JavaScript code
  // (for a non-blocking version of this function, use data()).
  const axis = 1;
  const labels = Array.from(examples.labels.argMax(axis).dataSync());
  const predictions = Array.from(output.argMax(axis).dataSync());

  showTestResults(examples, predictions, labels);
});

That's it: a simple handwritten digit recognizer written with TensorFlow.js is complete.
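The argMax step in the evaluation code can also be written out by hand. The sketch below applies the same idea to plain arrays: for each row of a [numExamples, numClasses] probability matrix it picks the index of the largest value, then counts how many predicted digits match the true digits. The function names and the two-example data are hypothetical; in the tutorial, argMax(axis) on the tensors does the first part.

```javascript
// Hypothetical helpers: row-wise argMax (like argMax(axis = 1) on a
// [numExamples, numClasses] tensor) and accuracy over the results.

function argMaxRows(matrix) {
  return matrix.map((row) => {
    let best = 0;
    for (let k = 1; k < row.length; k++) {
      if (row[k] > row[best]) best = k; // keeps the first index on ties
    }
    return best;
  });
}

function accuracy(predictions, labels) {
  let correct = 0;
  for (let i = 0; i < labels.length; i++) {
    if (predictions[i] === labels[i]) correct++;
  }
  return correct / labels.length;
}

// Two examples with 5 classes to keep it short (MNIST has 10).
const output = [
  [0.07, 0.1, 0.03, 0.75, 0.05], // the example from the comments: index 3
  [0.2, 0.5, 0.1, 0.1, 0.1],
];
const oneHotLabels = [
  [0, 0, 0, 1, 0],
  [1, 0, 0, 0, 0],
];
const predictions = argMaxRows(output);  // [3, 1]
const labels = argMaxRows(oneHotLabels); // [3, 0]
console.log(predictions, labels, accuracy(predictions, labels)); // accuracy 0.5
```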
The full code is available at: http://studio.bfw.wiki/Studio/Open/lang/html/id/15930025023507870045.html