How Do You Train An AI Model?

A deep dive into how models like ChatGPT get built.

I’ve already written about how Large Language Models like ChatGPT and Claude work. But how are they made?

How do you actually build and train an AI model to do all of the amazing stuff that ChatGPT can do?

What is training in the first place?

The process of creating a model is called training, kind of like training a kid to ride a bike, or whatever those people were doing at the Pokemon gyms. In old school Machine Learning – like the kind I went to school for – training broke down into 4 major steps:

  1. Acquire data on the problem: gather a dataset that you’ll use to teach your model to do what you want it to do, like classify an image or predict a stock price.

  2. Label your dataset: data needs context to be useful to the model, like what’s in an image or if a stock went up or down.

  3. Train your model: using some standard algorithms and linear algebra, teach the model what’s going on in your nicely curated dataset.

  4. Test your model: make sure what your model has learned transfers well to new data (and ideally, the real world).
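The four steps above can be sketched in a few lines of Python. Everything here is made up for illustration: a toy one-number "feature" per example and a deliberately simple model that just learns a single threshold.

```python
# 1. Acquire data: one feature per example (say, a darkness score for an image).
features = [0.1, 0.4, 0.35, 0.8, 0.9, 0.7, 0.2, 0.85]

# 2. Label the dataset: 1 = "has bugs", 0 = "no bugs".
labels = [0, 0, 0, 1, 1, 1, 0, 1]

# Hold back a test set so step 4 uses data the model has never seen.
train_x, train_y = features[:6], labels[:6]
test_x, test_y = features[6:], labels[6:]

def accuracy(threshold, xs, ys):
    """Fraction of examples this threshold classifies correctly."""
    preds = [1 if x >= threshold else 0 for x in xs]
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)

# 3. Train: search for the threshold that gets the most training examples right.
candidates = [i / 100 for i in range(101)]
best = max(candidates, key=lambda t: accuracy(t, train_x, train_y))

# 4. Test: check how well the learned threshold transfers to held-out data.
print(f"learned threshold: {best:.2f}")
print(f"test accuracy: {accuracy(best, test_x, test_y):.2f}")
```

Real models have millions of parameters instead of one threshold, but the train/test split and the "pick whatever scores best on the training data" loop are the same idea.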

machine learning in a nutshell

In a sense, training a model really is like teaching a kid how to do something, like riding a bike. It’s less about telling them how to do it, and more about giving them repetitions so they can figure out what’s going on for themselves. With some well timed guidance, of course.

In the same sense, a model is a decision-making machine. The way you train a model is by showing it many, many different situations and what the correct outcome is in those situations. The model uses some fancy math to learn the patterns in those situations and learns to apply them to new data. And like teaching a kid to do something, the way you train a model – from the method to the algorithms used – varies slightly depending on what you want the model to do.

Let’s run through a few examples.

Examples of ML training

Iowa has the most fertile soil in America (and maybe the world), so they grow the most profitable mechanized crops: corn and soybeans. Farmers want to use drones to scour their fields and detect if there are pests eating away at their valuable plants. So one of them wants to build a model that looks at an image (the input) and tells you if that image contains any bugs (the output). How would you train a model like this?

a brown stink bug on a corn plant

This is a brown stink bug (real name), which is bad news for your corn.

You’d start by gathering (curating, perhaps) a dataset for your model to learn from. You’d get hundreds or thousands of images of corn plants, some without bugs and some with bugs. You’d label each image: does it have bugs or not? And then you’d feed that data to the model. Through some fancy algorithms, it learns which types of pixels in images of this nature tend to indicate bugs, and which tend to indicate no bugs.

But enough about Iowa. What if you wanted to get rich by training a model that predicts the price of a stock on a given day? Something that to my knowledge nobody has ever tried to do before???? You want your model to take a new hour or day (the input) and predict whether the stock price is going to go up or down from where it is now (the output). How would you train a model like that?

You’d start by gathering historical data on what the stock price has done so far, ideally down to a fine granularity like hourly. Then you’d label that data with information on whether the stock price went up or down. And then you’d feed that data to the model. Through some fancy algorithms, it learns (ideally) which times of day the stock price tends to go up.
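The labeling step here is mechanical enough to show in code. This sketch uses made-up hourly prices and turns each hour into a training example labeled 1 if the price went up over the next hour, 0 otherwise:

```python
# Made-up hourly prices for some stock.
prices = [100.0, 100.5, 100.2, 101.0, 100.8, 101.5]

# Each training example pairs an input (hour index, price now)
# with a label: did the price go up by the next hour?
examples = []
for hour in range(len(prices) - 1):
    went_up = 1 if prices[hour + 1] > prices[hour] else 0
    examples.append(((hour, prices[hour]), went_up))

print(examples)
```

Note the last hour gets dropped: there’s no “next hour” to label it with, which is a tiny preview of the many ways time-series data makes labeling fiddly.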

How training works under the hood

I keep mentioning these “fancy algorithms” and it’s now time to discuss those. The magic that happens when a model digests your data and learns a pattern is actually not magic at all; it’s math. Specifically, models use something called a loss function to figure out if they’re doing well or not.

A loss function is something you design that tells the model when its answers are right and when they’re wrong. There are a bunch of different types, some tailored to specific ML tasks like image classification (bugs) or regression (stock prices).

how a loss function works

In Iowa, one way this can go down is that your model will start by guessing at random whether an image has a bug in it. If it’s right, it gets a point. If it’s wrong, it loses a point. After enough of these iterations, it starts to learn what means bug and what doesn’t.

On Wall Street, your model is trained slightly differently. It plots all of the stock price fluctuations on a graph and uses a loss function to find the most likely relationship between time and price. If the line is good, it gets a few points. If it’s bad, it loses a few points. And after a few iterations, it starts to learn what means up and what means down.

Like I said, lots of different types of loss functions. But behind the scenes, they’re just telling your model “good job” or “try again.” Hence why it’s so important to have nicely curated, labeled data with a clear right and wrong answer.
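Here’s one concrete loss function in action. This is a minimal sketch, assuming squared error (a common choice for regression) and a handful of made-up (time, price) points: the model is just a line, the loss scores how far its predictions are from the real prices, and gradient descent nudges the line toward lower loss – the math version of “good job” / “try again.”

```python
# Made-up data: price actually follows price = 2 * time + 10.
times = [0.0, 1.0, 2.0, 3.0]
prices = [10.0, 12.0, 14.0, 16.0]

def loss(slope, intercept):
    """Mean squared error: how wrong the line's predictions are, on average."""
    return sum((slope * t + intercept - p) ** 2
               for t, p in zip(times, prices)) / len(times)

slope, intercept = 0.0, 0.0  # start with a bad guess
lr = 0.05                    # learning rate: step size per iteration

for step in range(2000):
    # Gradients: which direction to nudge each parameter to reduce the loss.
    d_slope = sum(2 * (slope * t + intercept - p) * t
                  for t, p in zip(times, prices)) / len(times)
    d_intercept = sum(2 * (slope * t + intercept - p)
                      for t, p in zip(times, prices)) / len(times)
    slope -= lr * d_slope
    intercept -= lr * d_intercept

print(f"slope: {slope:.2f}, intercept: {intercept:.2f}, "
      f"loss: {loss(slope, intercept):.6f}")
```

After a couple thousand nudges, the line lands almost exactly on slope 2 and intercept 10, and the loss drops to nearly zero. Training a giant neural network is, at its core, this same loop with vastly more parameters and data.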

How do you train generative AI?

Training “old school” ML models is straightforward enough. But what about generative AI, specifically text generation models like ChatGPT? [How do you train those?](https://medium.com/data-science-at-microsoft/how-large-language-models-work-91c...