TinyML: Getting Started with TensorFlow Lite for Microcontrollers
2020-07-06 | By ShawnHymel
License: Attribution
TensorFlow is a popular open source software library (developed by Google) for performing machine learning tasks. A subset of this library is TensorFlow Lite for Microcontrollers, which allows us to run inference on microcontrollers. Note that “inference” is just using the model to make predictions, classifications, or decisions. It does not include training the model.
Because machine learning (especially neural networks and deep learning) is computationally expensive, TensorFlow Lite for Microcontrollers requires you to use a 32-bit processor, such as an ARM Cortex-M or ESP32. Also note that the library is (mostly) written in C++, so you will need to use a C++ compiler.
If you would like to see this tutorial in video form, please check out this video:
Overview
As microcontrollers are limited in resources, we generally need to perform model training on our computer (or a remote server) first. Once we have a model, we convert it to a FlatBuffer (.tflite) file and then convert that file to a constant C array that we include in our firmware program.
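The FlatBuffer-to-C-array conversion is often done with the xxd tool. As a sketch (assuming your FlatBuffer file is named sine_model.tflite, matching the model used later in this tutorial), the following command produces a C header containing the model as a byte array:
xxd -i sine_model.tflite > sine_model.h
Note that xxd derives the array name from the file name (sine_model_tflite in this case), so you may need to rename the array to match what your code expects.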
On the microcontroller, we run the TensorFlow Lite for Microcontrollers library, which uses our model to perform inference. For example, let’s say we trained a model to classify whether there is a cat in a photo. If we use that model on our microcontroller, we can feed it unseen data (i.e. photos), and it will tell us whether or not it thinks there is a cat in each photo.
Model Training
We are going to use a model trained in a previous tutorial. Please follow the steps in this tutorial to generate a model in .h5, .tflite, and .h formats. Download all three files to your computer.
Development Environment
This tutorial will show you how to generate source code files in TensorFlow Lite that you can use as a library in any microcontroller build system (Arduino, make, Eclipse, etc.). However, I am going to specifically show you how to include this library in STM32CubeIDE. There are a few reasons for this: I am most familiar with this IDE right now, and we can do a side-by-side comparison of TensorFlow Lite and the STM32 X-Cube-AI library in the next tutorial.
I will be running this demo on a Nucleo-L432KC development board.
See this video to get started with STM32CubeIDE.
Generate TensorFlow Lite File Structure
The creators of TensorFlow want you to use the Make build tool to generate a number of example projects that you can use as templates for your microcontroller projects. While this can work well, I wanted to show how you can modify the procedure to use TensorFlow Lite as a library instead of a starter project.
Note that you will need Linux or macOS for this next part. I have not gotten the auto-generation of TensorFlow Lite projects working in Windows. I was able to do this part on a Raspberry Pi.
In a new terminal, install the following:
sudo apt update
sudo apt install make git python3.7 python3-pip zip
Note that you may need to alias python to python3 and pip to pip3 for some of the TensorFlow Lite Make commands to work. In ~/.bashrc, add the following to the end:
alias python=python3
alias pip=pip3
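Then, reload the file so the aliases take effect in your current terminal:
source ~/.bashrc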
We’ll get the newest version of the TensorFlow repository (use a depth of 1 to get only the latest revision):
git clone --depth 1 https://github.com/tensorflow/tensorflow.git
Navigate into the tensorflow directory and run the Makefile in the TensorFlow Lite for Microcontrollers directory:
cd tensorflow
make -f tensorflow/lite/micro/tools/make/Makefile TAGS="portable_optimized" generate_non_kernel_projects
This will take a few minutes, so be patient. It generates a number of example projects and source code for you to use as a starting point.
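If you would like to see what was generated (and find the <os_cpu> directory name used in a later step), you can list the gen directory:
ls tensorflow/lite/micro/tools/make/gen/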
Create Project
In STM32CubeIDE, create a new STM32 project for your Nucleo-L432KC board. Give your project a name and make sure C++ is selected as the target language.
Enable timer 16 (TIM16) and give it a prescaler of (80 - 1), which divides the 80 MHz timer clock down to 1 MHz so that it will tick once per microsecond, and a reload value of (65536 - 1).
In Clock Configuration, select HSI as your PLL Source and type “80” into the HCLK box. Press ‘enter,’ and the CubeMX software should automatically configure all of the system clocks and prescalers to give you a system clock of 80 MHz.
Rename main.c to main.cpp.
Copy Model and TensorFlow Lite Source Code Files to Your Project
Find the sine_model.h file that you generated and downloaded during the training step. Copy that file to <your_project_directory>/Core/Inc. It contains a byte array of your neural network model in FlatBuffer format.
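If you open sine_model.h, you should see something like the following sketch (the exact bytes and length depend on your trained model; the length shown here is hypothetical):
unsigned char sine_model[] = {
  0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33,  // "TFL3" FlatBuffer file identifier
  /* ...the rest of the model bytes... */
};
unsigned int sine_model_len = 2640;  // hypothetical length; yours will differ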
Go into <tensorflow_repo>/tensorflow/lite/micro/tools/make/gen/<os_cpu>/prj/hello_world/make and copy tensorflow and third_party. Note that <os_cpu> will change, depending on where you ran the make command (for example, it is linux_armv7l for me, as I ran make on a Raspberry Pi).
Paste them in <your_project_directory>/tensorflow_lite, creating the tensorflow_lite directory as necessary.
Go into <your_project_directory>/tensorflow_lite/tensorflow/lite/micro and delete the examples folder, as it contains a template main.c application that we do not want in our project. Feel free to examine it to see how TensorFlow recommends you create firmware projects.
In STM32CubeIDE, right-click on your project and select Refresh. You should now see your model file and tensorflow_lite directory appear in your project.
Include Headers and Source in Build Process
Even though the source files are in our project, we still need to tell our IDE to include them in the build process. Go to Project > Properties. In that window, go to C/C++ General > Paths and Symbols > Includes tab > GNU C. Click Add. In the pop-up, click Workspace. Select the tensorflow_lite directory in your project. Check Add to all configurations and Add to all languages.
Repeat this process to add the following directories in tensorflow_lite/third_party:
- flatbuffers/include
- gemmlowp
- ruy
Head to the Source Location tab and add <your_project_directory>/tensorflow_lite to both the Debug and Release configurations.
Click Apply and Close and Yes (if asked to rebuild the index).
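If you are using a plain Makefile or another build system instead of STM32CubeIDE, these IDE settings correspond roughly to the following GCC include flags (a sketch, assuming the directory layout above):
-I<your_project_directory>/tensorflow_lite
-I<your_project_directory>/tensorflow_lite/third_party/flatbuffers/include
-I<your_project_directory>/tensorflow_lite/third_party/gemmlowp
-I<your_project_directory>/tensorflow_lite/third_party/ruy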
Update Debug Code
We need to make one change to the TensorFlow library so that it will support debug output over a serial port. Open <your_project_directory>/tensorflow_lite/tensorflow/lite/micro/debug_log.cc.
Update the code to the following:
#include "tensorflow/lite/micro/debug_log.h"
//#include <cstdio>
//
//extern "C" void DebugLog(const char* s) { fprintf(stderr, "%s", s); }
extern "C" void __attribute__((weak)) DebugLog(const char* s) {
// To be implemented by user
}
We comment out the original implementation of DebugLog (as we are not supporting fprintf) and add our own with the weak attribute. This allows us to provide the actual implementation in main.cpp, where we can use the STM32 HAL to output debugging information over UART.
Save and close the file.
Write Your Main Program
Open main.cpp and add the following sections between the user code comment guards (e.g. /* USER CODE BEGIN … */ and /* USER CODE END … */). Note that the auto-generated code may differ if you are using a different microcontroller or development board. Also note the custom implementation of DebugLog near the bottom, which overrides the weak definition in debug_log.cc.
/* USER CODE BEGIN Header */
/**
******************************************************************************
* @file : main.c
* @brief : Main program body
******************************************************************************
* @attention
*
* <h2><center>© Copyright (c) 2020 STMicroelectronics.
* All rights reserved.</center></h2>
*
* This software component is licensed by ST under BSD 3-Clause license,
* the "License"; You may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
* opensource.org/licenses/BSD-3-Clause
*
******************************************************************************
*/
/* USER CODE END Header */
/* Includes ------------------------------------------------------------------*/
#include "main.h"
/* Private includes ----------------------------------------------------------*/
/* USER CODE BEGIN Includes */
#include <string.h>
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/version.h"
#include "sine_model.h"
/* USER CODE END Includes */
/* Private typedef -----------------------------------------------------------*/
/* USER CODE BEGIN PTD */
/* USER CODE END PTD */
/* Private define ------------------------------------------------------------*/
/* USER CODE BEGIN PD */
/* USER CODE END PD */
/* Private macro -------------------------------------------------------------*/
/* USER CODE BEGIN PM */
/* USER CODE END PM */
/* Private variables ---------------------------------------------------------*/
TIM_HandleTypeDef htim16;
UART_HandleTypeDef huart2;
/* USER CODE BEGIN PV */
// TFLite globals
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* model_input = nullptr;
TfLiteTensor* model_output = nullptr;
// Create an area of memory to use for input, output, and other TensorFlow
// arrays. You'll need to adjust this by compiling, running, and looking
// for errors.
constexpr int kTensorArenaSize = 2 * 1024;
__attribute__((aligned(16))) uint8_t tensor_arena[kTensorArenaSize];
} // namespace
/* USER CODE END PV */
/* Private function prototypes -----------------------------------------------*/
void SystemClock_Config(void);
static void MX_GPIO_Init(void);
static void MX_USART2_UART_Init(void);
static void MX_TIM16_Init(void);
/* USER CODE BEGIN PFP */
/* USER CODE END PFP */
/* Private user code ---------------------------------------------------------*/
/* USER CODE BEGIN 0 */
/* USER CODE END 0 */
/**
* @brief The application entry point.
* @retval int
*/
int main(void)
{
/* USER CODE BEGIN 1 */
char buf[50];
int buf_len = 0;
TfLiteStatus tflite_status;
uint32_t num_elements;
uint32_t timestamp;
float y_val;
/* USER CODE END 1 */
/* MCU Configuration--------------------------------------------------------*/
/* Reset of all peripherals, Initializes the Flash interface and the Systick. */
HAL_Init();
/* USER CODE BEGIN Init */
/* USER CODE END Init */
/* Configure the system clock */
SystemClock_Config();
/* USER CODE BEGIN SysInit */
/* USER CODE END SysInit */
/* Initialize all configured peripherals */
MX_GPIO_Init();
MX_USART2_UART_Init();
MX_TIM16_Init();
/* USER CODE BEGIN 2 */
// Start timer/counter
HAL_TIM_Base_Start(&htim16);
// Set up logging (modify tensorflow/lite/micro/debug_log.cc)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
// Say something to test error reporter
error_reporter->Report("STM32 TensorFlow Lite test");
// Map the model into a usable data structure
model = tflite::GetModel(sine_model);
if (model->version() != TFLITE_SCHEMA_VERSION)
{
error_reporter->Report("Model version does not match Schema");
while(1);
}
// Pull in only needed operations (should match NN layers). Template parameter
// <n> is number of ops to be added. Available ops:
// tensorflow/lite/micro/kernels/micro_ops.h
static tflite::MicroMutableOpResolver<1> micro_op_resolver;
// Add dense neural network layer operation
tflite_status = micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
if (tflite_status != kTfLiteOk)
{
error_reporter->Report("Could not add FULLY CONNECTED op");
while(1);
}
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
tflite_status = interpreter->AllocateTensors();
if (tflite_status != kTfLiteOk)
{
error_reporter->Report("AllocateTensors() failed");
while(1);
}
// Assign model input and output buffers (tensors) to pointers
model_input = interpreter->input(0);
model_output = interpreter->output(0);
// Get number of elements in input tensor
num_elements = model_input->bytes / sizeof(float);
buf_len = sprintf(buf, "Number of input elements: %lu\r\n", num_elements);
HAL_UART_Transmit(&huart2, (uint8_t *)buf, buf_len, 100);
/* USER CODE END 2 */
/* Infinite loop */
/* USER CODE BEGIN WHILE */
while (1)
{
// Fill input buffer (use test value)
for (uint32_t i = 0; i < num_elements; i++)
{
model_input->data.f[i] = 2.0f;
}
// Get current timestamp
timestamp = htim16.Instance->CNT;
// Run inference
tflite_status = interpreter->Invoke();
if (tflite_status != kTfLiteOk)
{
error_reporter->Report("Invoke failed");
}
// Read output (predicted y) of neural network
y_val = model_output->data.f[0];
// Print output of neural network along with inference time (microseconds)
buf_len = sprintf(buf,
"Output: %f | Duration: %lu\r\n",
y_val,
htim16.Instance->CNT - timestamp);
HAL_UART_Transmit(&huart2, (uint8_t *)buf, buf_len, 100);
// Wait before doing it again
HAL_Delay(500);
/* USER CODE END WHILE */
/* USER CODE BEGIN 3 */
}
/* USER CODE END 3 */
}
/**
* @brief System Clock Configuration
* @retval None
*/
void SystemClock_Config(void)
{
RCC_OscInitTypeDef RCC_OscInitStruct = {0};
RCC_ClkInitTypeDef RCC_ClkInitStruct = {0};
RCC_PeriphCLKInitTypeDef PeriphClkInit = {0};
/** Initializes the CPU, AHB and APB busses clocks
*/
RCC_OscInitStruct.OscillatorType = RCC_OSCILLATORTYPE_HSI;
RCC_OscInitStruct.HSIState = RCC_HSI_ON;
RCC_OscInitStruct.HSICalibrationValue = RCC_HSICALIBRATION_DEFAULT;
RCC_OscInitStruct.PLL.PLLState = RCC_PLL_ON;
RCC_OscInitStruct.PLL.PLLSource = RCC_PLLSOURCE_HSI;
RCC_OscInitStruct.PLL.PLLM = 1;
RCC_OscInitStruct.PLL.PLLN = 10;
RCC_OscInitStruct.PLL.PLLP = RCC_PLLP_DIV7;
RCC_OscInitStruct.PLL.PLLQ = RCC_PLLQ_DIV2;
RCC_OscInitStruct.PLL.PLLR = RCC_PLLR_DIV2;
if (HAL_RCC_OscConfig(&RCC_OscInitStruct) != HAL_OK)
{
Error_Handler();
}
/** Initializes the CPU, AHB and APB busses clocks
*/
RCC_ClkInitStruct.ClockType = RCC_CLOCKTYPE_HCLK|RCC_CLOCKTYPE_SYSCLK
|RCC_CLOCKTYPE_PCLK1|RCC_CLOCKTYPE_PCLK2;
RCC_ClkInitStruct.SYSCLKSource = RCC_SYSCLKSOURCE_PLLCLK;
RCC_ClkInitStruct.AHBCLKDivider = RCC_SYSCLK_DIV1;
RCC_ClkInitStruct.APB1CLKDivider = RCC_HCLK_DIV1;
RCC_ClkInitStruct.APB2CLKDivider = RCC_HCLK_DIV1;
if (HAL_RCC_ClockConfig(&RCC_ClkInitStruct, FLASH_LATENCY_4) != HAL_OK)
{
Error_Handler();
}
PeriphClkInit.PeriphClockSelection = RCC_PERIPHCLK_USART2;
PeriphClkInit.Usart2ClockSelection = RCC_USART2CLKSOURCE_PCLK1;
if (HAL_RCCEx_PeriphCLKConfig(&PeriphClkInit) != HAL_OK)
{
Error_Handler();
}
/** Configure the main internal regulator output voltage
*/
if (HAL_PWREx_ControlVoltageScaling(PWR_REGULATOR_VOLTAGE_SCALE1) != HAL_OK)
{
Error_Handler();
}
}
/**
* @brief TIM16 Initialization Function
* @param None
* @retval None
*/
static void MX_TIM16_Init(void)
{
/* USER CODE BEGIN TIM16_Init 0 */
/* USER CODE END TIM16_Init 0 */
/* USER CODE BEGIN TIM16_Init 1 */
/* USER CODE END TIM16_Init 1 */
htim16.Instance = TIM16;
htim16.Init.Prescaler = 80 - 1;
htim16.Init.CounterMode = TIM_COUNTERMODE_UP;
htim16.Init.Period = 65536 - 1;
htim16.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
htim16.Init.RepetitionCounter = 0;
htim16.Init.AutoReloadPreload = TIM_AUTORELOAD_PRELOAD_DISABLE;
if (HAL_TIM_Base_Init(&htim16) != HAL_OK)
{
Error_Handler();
}
/* USER CODE BEGIN TIM16_Init 2 */
/* USER CODE END TIM16_Init 2 */
}
/**
* @brief USART2 Initialization Function
* @param None
* @retval None
*/
static void MX_USART2_UART_Init(void)
{
/* USER CODE BEGIN USART2_Init 0 */
/* USER CODE END USART2_Init 0 */
/* USER CODE BEGIN USART2_Init 1 */
/* USER CODE END USART2_Init 1 */
huart2.Instance = USART2;
huart2.Init.BaudRate = 115200;
huart2.Init.WordLength = UART_WORDLENGTH_8B;
huart2.Init.StopBits = UART_STOPBITS_1;
huart2.Init.Parity = UART_PARITY_NONE;
huart2.Init.Mode = UART_MODE_TX_RX;
huart2.Init.HwFlowCtl = UART_HWCONTROL_NONE;
huart2.Init.OverSampling = UART_OVERSAMPLING_16;
huart2.Init.OneBitSampling = UART_ONE_BIT_SAMPLE_DISABLE;
huart2.AdvancedInit.AdvFeatureInit = UART_ADVFEATURE_NO_INIT;
if (HAL_UART_Init(&huart2) != HAL_OK)
{
Error_Handler();
}
/* USER CODE BEGIN USART2_Init 2 */
/* USER CODE END USART2_Init 2 */
}
/**
* @brief GPIO Initialization Function
* @param None
* @retval None
*/
static void MX_GPIO_Init(void)
{
GPIO_InitTypeDef GPIO_InitStruct = {0};
/* GPIO Ports Clock Enable */
__HAL_RCC_GPIOC_CLK_ENABLE();
__HAL_RCC_GPIOA_CLK_ENABLE();
__HAL_RCC_GPIOB_CLK_ENABLE();
/*Configure GPIO pin Output Level */
HAL_GPIO_WritePin(LD3_GPIO_Port, LD3_Pin, GPIO_PIN_RESET);
/*Configure GPIO pin : LD3_Pin */
GPIO_InitStruct.Pin = LD3_Pin;
GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP;
GPIO_InitStruct.Pull = GPIO_NOPULL;
GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_LOW;
HAL_GPIO_Init(LD3_GPIO_Port, &GPIO_InitStruct);
}
/* USER CODE BEGIN 4 */
// Custom implementation of DebugLog from TensorFlow
extern "C" void DebugLog(const char* s)
{
HAL_UART_Transmit(&huart2, (uint8_t *)s, strlen(s), 100);
}
/* USER CODE END 4 */
/**
* @brief This function is executed in case of error occurrence.
* @retval None
*/
void Error_Handler(void)
{
/* USER CODE BEGIN Error_Handler_Debug */
/* User can add his own implementation to report the HAL error return state */
/* USER CODE END Error_Handler_Debug */
}
#ifdef USE_FULL_ASSERT
/**
* @brief Reports the name of the source file and the source line number
* where the assert_param error has occurred.
* @param file: pointer to the source file name
* @param line: assert_param error line source number
* @retval None
*/
void assert_failed(uint8_t *file, uint32_t line)
{
/* USER CODE BEGIN 6 */
/* User can add his own implementation to report the file name and line number,
tex: printf("Wrong parameters value: file %s on line %d\r\n", file, line) */
/* USER CODE END 6 */
}
#endif /* USE_FULL_ASSERT */
/************************ (C) COPYRIGHT STMicroelectronics *****END OF FILE****/
Please refer to the video if you would like an explanation of what each of these sections of code does.
Add printf Float Support
In STM32CubeIDE, printf (and its variants) does not support floating point values by default. To add that support, head to Project > Properties > C/C++ Build > Settings > Tool Settings tab > MCU G++ Linker > Miscellaneous. In the Other flags pane, add the following flag:
-u _printf_float
Do this for both Debug and Release configurations.
Run Debug Mode
Build your project and click Run > Debug. In the Debug perspective, click the Play/Pause button to begin running your code on the microcontroller. Open a serial terminal (e.g. PuTTY) to your development board, and you should see the output of inference, which should match our test in Google Colab (estimating sin(2.0)). You should also see how long it took to run inference (in microseconds).
Run Release Mode
The Debug configuration passes a -DDEBUG flag to the g++ compiler, which enables some options in the TensorFlow Lite library (but slows things down).
Select Project > Build Configurations > Set Active > Release. Then, select Project > Build.
Open Run > Run Configurations… In that window, set C/C++ Application to Release/<your_project_name>.elf. Set Build Configuration to Release.
Click Run. Your project should be rebuilt. If you look at the output in the console, you can estimate the flash and RAM usage of your program. Find the output of the arm-none-eabi-size tool.
Text + data gives you Flash usage (about 50,032 bytes in this case) and data + bss gives you estimated RAM usage (about 4,744 bytes in this case).
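For reference, the arm-none-eabi-size output is formatted like this (the individual text/data/bss values below are hypothetical; only the totals match the numbers above):
   text    data     bss     dec     hex filename
  49564     468    4276   54308    d424 <your_project_name>.elf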
Open your serial terminal, and you should see the output of your program. By switching to the Release configuration (removing the -DDEBUG flag), we brought our inference time from around 368 microseconds to around 104 microseconds.
Going Further
I hope this has helped you get started using TensorFlow Lite for Microcontrollers! While the example program is not particularly useful (predicting sine values), it should offer a decent template for making your own projects.
Here are some other articles about using TensorFlow Lite for Microcontrollers:
- https://www.tensorflow.org/lite/microcontrollers
- https://www.tensorflow.org/lite/microcontrollers/get_started
- Deploying a TensorFlow Lite Model to Arduino
- Edge AI Anomaly Detection
Have questions or comments? Continue the conversation on TechForum, DigiKey's online community and technical resource.