Train a model to classify and localize triangles and rectangles
Description
This example page shows inference with a pretrained object-detection model that can classify and localize (i.e., give the position of) target shapes in simple synthetic scenes. The model is trained in tfjs-node (see train.js in the source folder).
The target objects that we want to detect are triangles and rectangles; the background contains line segments and circles that the model should ignore. Each synthetic scene contains exactly one target object. For each scene, the model predicts the bounding box of the target object (shown as the blue box), which can be compared with the true bounding box (shown as the red box). The inference output also contains the class of the shape (triangle vs. rectangle).
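To make the comparison between the blue (predicted) and red (true) boxes concrete, here is a minimal sketch in plain JavaScript with no TF.js dependency. It assumes the model's raw output is the five-element vector [indicator, left, right, top, bottom] described in the head-model comments; the helper names decodePrediction and iou are hypothetical. The 0.5 threshold, mapping high indicator values to "triangle", and using intersection over union (IoU) as the comparison score are illustrative assumptions, not necessarily what this example does. Coordinates are assumed to grow downward as on a canvas, so top < bottom.

```javascript
// Hypothetical helper: split the five raw outputs into a class label and a box.
// ASSUMPTIONS: indicator > threshold means "triangle"; layout is
// [indicator, left, right, top, bottom] as in the head-model comments.
function decodePrediction(output, threshold = 0.5) {
  const [indicator, left, right, top, bottom] = output;
  return {
    shape: indicator > threshold ? 'triangle' : 'rectangle',
    box: {left, right, top, bottom},
  };
}

// Intersection over union of two boxes given as {left, right, top, bottom}
// in pixels. ASSUMPTION: y grows downward, so top < bottom.
function iou(a, b) {
  const interW = Math.max(0, Math.min(a.right, b.right) - Math.max(a.left, b.left));
  const interH = Math.max(0, Math.min(a.bottom, b.bottom) - Math.max(a.top, b.top));
  const inter = interW * interH;
  const areaA = (a.right - a.left) * (a.bottom - a.top);
  const areaB = (b.right - b.left) * (b.bottom - b.top);
  const union = areaA + areaB - inter;
  return union > 0 ? inter / union : 0;
}

// Example: a predicted (blue) box compared with the true (red) box.
const predicted = decodePrediction([0.9, 10, 50, 20, 60]);
const trueBox = {left: 12, right: 52, top: 20, bottom: 60};
console.log(predicted.shape);                         // triangle
console.log(iou(predicted.box, trueBox).toFixed(3));  // 0.905
```

An IoU of 1 means the two boxes coincide exactly, and 0 means they do not overlap at all.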
```javascript
// Name prefixes of layers that will be unfrozen during fine-tuning.
const topLayerGroupNames = ['conv_pw_9', 'conv_pw_10', 'conv_pw_11'];
// Name of the layer that will become the top layer of the truncated base.
const topLayerName = `${topLayerGroupNames[topLayerGroupNames.length - 1]}_relu`;
```

The head model, placed on top of the "conv_pw_11_relu" layer (output shape [null, 14, 14, 128]), has two dense layers with a total of 5,018,805 parameters.
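One way to read these two constants: the pretrained base is cut off at topLayerName, and only layers whose names start with one of the topLayerGroupNames prefixes are unfrozen. The sketch below expresses that logic over plain layer-name strings so it runs without loading MobileNet; the helper planFineTuning and the sample layer names are hypothetical. In TF.js the same plan could be applied with model.getLayer(topLayerName), tf.model({inputs: model.inputs, outputs: layer.output}), and each layer's trainable property, though the example's train.js may organize this differently.

```javascript
// Constants copied from the snippet above.
const topLayerGroupNames = ['conv_pw_9', 'conv_pw_10', 'conv_pw_11'];
const topLayerName = `${topLayerGroupNames[topLayerGroupNames.length - 1]}_relu`;

// Hypothetical helper: given the base model's layer names in order, keep
// everything up to and including topLayerName (truncation), and mark as
// trainable only the kept layers whose names start with a group prefix.
function planFineTuning(layerNames) {
  const cut = layerNames.indexOf(topLayerName);
  if (cut < 0) throw new Error(`Layer ${topLayerName} not found`);
  const kept = layerNames.slice(0, cut + 1);
  const trainable = kept.filter(
      (name) => topLayerGroupNames.some((prefix) => name.startsWith(prefix)));
  return {kept, trainable};
}

// Illustrative (not exhaustive) layer names in MobileNet v1 style.
const names = [
  'conv_pw_8_relu',
  'conv_dw_9', 'conv_pw_9', 'conv_pw_9_bn', 'conv_pw_9_relu',
  'conv_pw_10', 'conv_pw_10_relu',
  'conv_pw_11', 'conv_pw_11_bn', 'conv_pw_11_relu',
  'conv_dw_12', 'conv_pw_12',
];
const plan = planFineTuning(names);
console.log(plan.kept.length);  // 10
console.log(plan.trainable);
```

Matching on prefixes unfreezes each group's convolution, batch-normalization, and activation layers together. A prefix such as conv_pw_1 would also match conv_pw_10, so prefixes need to be chosen with care.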
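```javascript
// In Node.js (as in train.js), '@tensorflow/tfjs-node' can be required instead.
const tf = require('@tensorflow/tfjs');

function buildNewHead(inputShape) {
  const newHead = tf.sequential();
  // Flatten the base model's feature map (e.g. [14, 14, 128]) into a vector.
  newHead.add(tf.layers.flatten({inputShape}));
  newHead.add(tf.layers.dense({units: 200, activation: 'relu'}));
  // Five output units:
  // - The first is a shape indicator: predicts whether the target
  //   shape is a triangle or a rectangle.
  // - The remaining four units are for bounding-box prediction:
  //   [left, right, top, bottom] in pixels.
  newHead.add(tf.layers.dense({units: 5}));
  return newHead;
}
```

It is a fairly straightforward approach: a single hidden layer followed by a linear output layer that predicts the shape indicator and the box coordinates together.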
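The 5,018,805 parameter count quoted for the head can be checked by hand: flatten has no weights, and each dense layer (with its default bias) has inputUnits * units weights plus units biases. A short sketch of that arithmetic, where denseParams is a hypothetical helper:

```javascript
// Parameters of a dense layer: one weight per (input, unit) pair plus one bias per unit.
function denseParams(inputUnits, units) {
  return inputUnits * units + units;
}

// Flatten turns the [14, 14, 128] feature map into a vector; it has no parameters.
const flattened = 14 * 14 * 128;             // 25088
const hidden = denseParams(flattened, 200);  // 5017800
const output = denseParams(200, 5);          // 1005
console.log(hidden + output);                // 5018805
```

Almost all of the parameters sit in the first dense layer, because it connects every one of the 25,088 flattened features to each of its 200 units.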