Introduction to Deep Learning Theory
Date Posted:
September 25, 2023
Date Recorded:
September 11, 2023
Speaker(s):
Boris Hanin, Princeton University
Brains, Minds and Machines Summer Course 2023
PRESENTER: OK, so our penultimate talk for the day, because there's one more after dinner. But the last theory talk today is going to be by Boris Hanin. He's a professor at Princeton in operations research and financial engineering. And he works on deep learning theory, probability, and spectral theory as well. And yeah, start whenever you--
BORIS HANIN: OK, thanks. OK, it's good to meet you all. Most of you I don't know. I don't really come from the neuroscience kind of world. And so it was a bit of a challenge to figure out exactly what kind of talk to give. And I decided to go for a slightly unusual talk.
So I'm not really going to talk too much about my work. I'll mention some things that I do. But instead, I want to focus on a talk based on intuitions. OK, I feel like a lot of people here train neural networks, think about neural networks, one way or the other. But I suspect most people are not experts in deep learning theory.
And so rather than telling you the latest and greatest thing, which, OK, I'm very proud of and excited about, but I don't think it's going to be as exciting for you, I want to give a talk that is something like an intro to deep learning theory.
OK, so in particular, I tried to make it as elementary as possible. And so if something is not clear, it's really on me. So just like stop me and interrupt me. And also, because I'm not presenting my own work, I don't have to get anywhere in particular. OK, so that's a nice relief.
OK, so let me just tell you what my notation is for neural network, first of all. OK, so I write neural networks this way. So capital L is going to be the number of hidden layers. x is the input. And OK, I'm just going to talk about fully connected networks, though a lot of what I say can be generalized beyond that. OK, but this is just some kind of notation.
I've turned off the biases because I'm too lazy to write the b's. OK, so my notation is capital L is the number of hidden layers. And then the W_l are the weight matrices. They're N_l by N_(l-1). OK, so N_l is the width of my l-th hidden layer. That's what I mean by neural network.
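For reference, here is a minimal NumPy sketch of this notation. The particular widths, the tanh nonlinearity, and the choice to leave the readout layer linear are illustrative assumptions of mine, not anything fixed in the talk; biases are turned off as above.

```python
import numpy as np

def network(x, weights, sigma=np.tanh):
    # z(x) = W_{L+1} sigma( W_L ... sigma(W_1 x) ), biases turned off
    z = x
    for W in weights[:-1]:
        z = sigma(W @ z)          # W_l is N_l x N_{l-1}
    return weights[-1] @ z        # readout layer left linear (an assumption here)

widths = [10, 50, 50, 1]          # N_0, N_1, N_2, output; so L = 2 hidden layers
rng = np.random.default_rng(0)
weights = [rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_out, n_in))
           for n_in, n_out in zip(widths[:-1], widths[1:])]

print(network(rng.standard_normal(widths[0]), weights))
```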
AUDIENCE: [INAUDIBLE]
BORIS HANIN: So nothing I say is going to preclude the need to have biases. It's just it adds more to the picture, yeah. OK, so but my focus is like this. I want to list a couple of questions which I think of as some of the big questions in deep learning. And some of them, we've already encountered today. And some of them, as I try to explain and draw some pictures, will be a little different than what we talked about. And I just want to give you some, what I think of as, useful intuitions for thinking about them.
OK, so here's question number one. OK, we talked quite a bit about this already today. So why does optimization work when the loss-- I'll call it generically script L is not convex? At least, this is a question that was bothering people for a while when deep learning became popular again, which was prehistoric years, like 2012. OK, who remembers that far back? OK, yeah.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yeah, so some of the exact calculations I do will maybe be sigma equals ReLU or maybe even sigma equals the identity. So it's the easiest kind of calculation. And I'll state some theorems. But almost everything I say will hold for a general non-linearity. That's right. I'll specify when I get there.
OK. So that's question number one. And I want to talk about that. All these questions are related. OK, question number two, we've already heard a lot about. So how do we do feature learning? I'll just write it this way, feature learning and how to make sense of it.
And what about the curse of dimensionality? I'll just write it as COD. These things are related. Somehow, when I was growing up, which was the pre- prehistoric years, I was always taught, you'll never learn a function of a million variables. Really, you won't even learn a function of 20 variables. It's extremely hard to do.
And yet, OK, neural networks seem to do this. And it's not clear how this is supposed to work. And so, in particular, you're not learning generic functions in high dimensions. You're learning functions with some kind of structure. And there's a lot of work trying to figure out what this means.
OK, and the last question, which there were a lot of questions about in Surya's talk. But I think of it this way. How do you make sense of the alchemy, alchemy-- OK, I can't spell-- alchemy of hyperparameters? OK, what I mean by that is the following. How many people have trained a neural network?
OK. So if you haven't or if you have, a key thing to know is that training a neural network is a disaster most of the time. OK, it sounds good, and anyone who says, I'm going to throw a neural network at the problem, is either really a master or has no idea what they're talking about.
It's really like that, in the sense that you have to choose everything. Right? You have to choose sigma and the depth and the width and where to put the skip connections and what kind of optimizer to use. And moreover, the results you get in practice can be highly sensitive to these choices.
So people will say that deep learning is alchemy. You just do the thing that DeepMind told you to do, and it more or less works, something like that. But surely, it would be nice to have some principles. How do you go about making these choices in practice? What's a more efficient way to search over all the things you have to search over?
OK, so my plan is to try to say something about these three questions. And this is, basically, just going to be an overview of heuristics and some theorems people know how to prove. Sometimes the theorems are mine, but, like I said, I won't emphasize that too much. OK, is this OK? Can people read my writing and stuff? Yeah?
AUDIENCE: What was sigma one more time?
BORIS HANIN: So sigma is a non-linearity. OK, so maybe let me just be really clear. So x is an N0-dimensional vector. W1 maps x into an N1-dimensional space. And then you apply sigma to each element of that resulting vector. And you keep going in this way. That's just my and most people's notation for this.
So yeah, so I'm too lazy, because I don't do this on a computer, to put the brackets. But I read it from right to left. Most people do it some other way, down to top. Or OK, no one has a unique convention. OK, is this OK?
OK. So let me start by answering or giving some kind of answer to question number one. But before I do that, let me just honestly ask. If somebody asked you, why is optimization in practice tractable? It just seems to work. You initialize. You run gradient descent. You get something. Does anyone have an answer? Yeah?
OK, fair enough. So maybe I should be more clear about what I mean by work. So you seem to find, in practice, something close to a global optimum for the loss. So the kind of picture that people used to draw for me to scare me, when I tried to be naughty late at night, was this kind of picture for the loss landscape.
OK, I would wake up. I'd want a cookie, and my parents would draw that. And I would be scared back to bed. Why can you find the global optimum in a sea of bad local minima?
So my question is, why is this not the loss landscape for a deep neural network? And what do we even know about the loss landscape? It's in a million dimensions, so we can't quite visualize it. OK, Surya is not allowed to answer, for what it's worth.
Right. I mean, so I would say, we understand very little broadly about the loss landscape. There's lots of people who've done interesting work to try to say something. But I would say we know less than what we don't know, if that makes sense. Agreed? Very good.
So here's my like attempt to answer it. I'm going to try to draw a picture and a very vague mathematical intuition and a theorem, which, in some case, says this is true. OK, so I think really the answer-- and Tommy really told us this-- is overparameterization. That's really what makes optimization, essentially, easy. OK, I don't know if I spelled it right, over-- I think that's OK maybe.
So this just says that the number of parameters in your model-- OK, in my silly example here, just the number of weights that you have-- is typically, much, much bigger than the number of data points, which itself is typically much, much bigger than 1. OK, that's the regime in which we often work.
And people will say, wait a minute. Big language models don't necessarily have this. OK, it depends what you mean by a data point. How do you measure the effective number of data points? It's unclear.
But let me at least say, that whenever you're vastly overparameterized, you should expect to be able to easily minimize your loss. OK, if you don't want to think like a mathematician, but rather like a physicist, this, I think, after I draw a picture, is going to be obvious. OK, that's my goal to try to convey.
OK. So why is this? What's the picture I have in mind? Well, the picture I have in mind looks like this. OK. So over here, I'm going to draw R to the number of parameters, the space of parameters. This is where we typically think of optimization happening. Yeah?
Nothing yet about-- no, correct. I'm just saying, an empirical fact is you fit the training data, at least almost always. And why should that be? I'm going to try to picture proof it for you.
OK, so here you are. And for all you know, in the space of parameters, the loss looks really complicated. OK, I'm going to draw the level sets of the loss. Maybe the loss level sets look like this. OK. And so that's a total disaster. It looks like maybe these are two local minima. Who knows?
OK, and you're at some point theta. And you're desperate. And you don't know what to do. And the point is, on a computer, you get to move your vector of parameters theta in any old direction that you want. I mean, you choose the gradient of the loss. But these are free parameters for the model.
But it's not obvious why you should be able to minimize the loss. OK, and in particular, I haven't used anything about the structure of the data set. So I've got to put that in there somehow. So what am I going to do? OK, I'm going to do the only thing I know how to do, which is draw another box.
OK, so we're going to go from R to the number of parameters to R to the number of data points. OK, I'm told, although I literally know nothing about neuroscience, but I was once told by somebody who is a neuroscientist, that neuroscientists think about this. OK, so instead of thinking about the value of the set of parameters you get, why not just think about the predictions you get on the training data?
OK, there's a nice map which takes this parameter vector, that I called theta, and maps it. So I'll write it down here. Theta maps to just the set of values that your model takes on the training data points. I'll say x_i for the training data points. And theta is my setting of parameters.
This is a number of data dimensional vector. Just for simplicity, I'm saying the outputs are scalars. So I don't have to think of them as vectors. So each setting of parameters, I get the predictions on the training data. OK, very good.
And the kind of losses we use, think of a regression loss, mean squared error. Well, it's very complicated over here. But what kind of function is it of the outputs of the network? Huh? I heard something.
AUDIENCE: Convex.
BORIS HANIN: Convex, OK, yeah, it's trivial. OK, exactly. So if you were to draw the level sets, this is the dream of every person in optimization. These are the level sets. And there's just a minimum, and it's very clear where to go. Right?
So suppose, for example, that this-- I'll just write it as z of theta. I won't put all the notation down there. OK, at a given vector of parameters, you go somewhere in the space of predictions. OK, and it's very clear over here where you want to go. You just want to go straight towards the minimum. This is the gradient of the loss with respect to the param-- with respect to-- sorry, the outputs of the network. You want to go that way.
OK. So now let me tell you a story about why optimization should be easy. It should be easy for the following reason. If there exists a direction to go in the theta space, so that the resulting move in the z space overlaps with the direction I'd like to go, then I'm always going to win. I'm always going to be able to get down to the bottom of the well for the loss as a function of the network outputs.
OK. So here's the point. I'm now going to use that the number of parameters is much bigger than the number of data points and say, look, there are many, many directions I can go over here. And I just have to be able to go in one particular special direction down here. OK, so here's the intuition.
So when the number of parameters is much bigger than the number of data points, what you should expect is that the Jacobian of this map-- so the Jacobian of the network output as a function of theta. So again, you fix theta, and you ask, what's the derivative with respect to theta of all the outputs of your network evaluated on your training data points? This should have full rank.
OK, if you've ever taken a differential geometry class, this is called being a submersion. That's a nice mathematical property. And all it says is that, on this tangent space here, the Jacobian which maps you from there to this tangent space over here is surjective. Any direction you want to go, there exists a direction up here that maps to it.
OK, so now it's going to be one line of algebra to convince you that you're never going to get stuck in a bad local minimum. In fact, bad local minima will never exist. OK, so let me write it down, and then I'm going to let you digest it and ask me questions or tell me to hurry on.
OK, so let's assume that you know that for every setting of parameters theta, this Jacobian is literally full rank. It's maybe a little too much to ask for, but that's what I'll assume. OK, so let's ask, what happens at a stationary point of the loss?
OK, so here you are. And basically, gradient descent is going to get you to a point where the gradient of the loss is equal to 0 with respect to the parameters. That's what you should expect from any local method. OK, well, what's the only thing I know for multivariable calculus? OK, that's not true. I know several things.
What's the main thing I know, though? I know the chain rule. The loss depends on theta only through the values of the predictions on your training data, at least if I haven't regularized. So I'm just going to say, this implies that the gradient with respect to theta of z times the gradient with respect to z of l is equal to 0. That's just replacing the left-hand side by the chain rule.
OK, and now I'm done, right? If this is a full rank map, if this is surjective, then the only way that this times a vector equals 0 is if this vector equals 0. OK, so this implies that the gradient zl is equal to 0.
But L as a function of the network outputs only has one critical point, which is a global minimum, because it's convex. That's the whole point. So therefore, L of z of theta is minimal. It's 0 if I'm just thinking about mean squared error.
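Written out, the one-line argument on the board is the following, for mean squared error and assuming the Jacobian of the parameters-to-predictions map has full row rank at every theta:

```latex
\nabla_\theta \mathcal{L}(\theta)
= \Big(\tfrac{\partial z(\theta)}{\partial \theta}\Big)^{\!\top}\,\nabla_z \mathcal{L}\big(z(\theta)\big) = 0
\quad\Longrightarrow\quad
\nabla_z \mathcal{L}\big(z(\theta)\big) = 0,
```

since the (number of data points) by (number of parameters) Jacobian having full row rank means its transpose has trivial kernel. And because the loss, say L(z) equals the sum over i of (z_i minus y_i) squared, is convex in z, its only critical point in z is the global minimum, so the training data is fit exactly.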
OK, so this is my mental model for why you should think that, any time you have a lot of parameters-- and this has nothing really to do with the neural network-- anytime you have a lot, a lot of parameters, you should expect optimization to be easy. And that's just because data points don't have to fight each other.
Some parameters can help you fit one data point. Some parameters help you fit another data point. Everybody sings "Kumbaya." Notice, so full rank just means it has-- because the space down here is lower dimensional, it means that you will be able to move in as many dimensions as the number of data points, not number of parameters. This is a rectangular matrix.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: No, no, no, no, no, sorry. This is a matrix whose columns are the gradient with respect to theta of z at x1, up through the gradient with respect to theta of the output of your network at the last data point. OK, so it's a number of parameters by number of data points matrix.
AUDIENCE: I see.
BORIS HANIN: OK, or maybe the transpose, depending on how I wrote it. So because the number of columns is much smaller, I'm only asking that the columns are linearly independent. And you can translate it like this. Being full rank means that there exists a way to go in parameter space that does anything I want in the space of outputs. And that's reasonable.
If you only have one data point and a million parameters, I'm just asking that there's some way to move the parameters so that it moves the prediction on that data point. So, well, yes, because l is a convex function of the network outputs. The only minimum of a convex function, so the only critical point of a convex function, is a global minimum.
So what I've tried to argue is, that any time you reach a stationary point where the gradient of the loss with respect to the parameters vanishes, that's what you should expect gradient descent or a first-order algorithm to get you to. Automatically, the gradient of the loss with respect to the network outputs needs to vanish. And therefore, you must be at the global minimum. You've exactly fit your training data.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: That's the argument. So so far, nothing about generalization. I'm not sure if that was part of the question. Yeah. OK. It's OK. Yeah, please?
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Well, I mean, honestly, that's one of the first things I was taught about actually using neural networks. You should just take a bunch of data and make sure you can exactly fit it. And if you don't, you need a bigger model.
OK, I'm not sure if that's standard advice anymore, in a sense. But I think the short answer is, yes. I think the long answer is, no. And it's ironic because yes is longer than no, but yes.
Yeah, yeah, right, indeed. And if you have much more degrees of freedom than you have equations, it should be easier to fit. So we've seen this kind of theme. I'm just saying, this, for me, is a rather convincing heuristic. And it's all about, OK, is it really true that you have full rank, and so on? Yeah.
So you should think of it literally as the gradient of the output with respect to the parameters. All these vectors are linearly independent. So you could freeze all but one of the data points and move the prediction just on that data point. That's what it's telling you.
And if you can change one data point at a time, you can never get stuck in a local minimum, right? Because if I was in a local minimum, why don't I just move that one data point. And then I would get lower. And so there's no bad local minimum. That's the kind of argument. Yeah, great.
OK, strong? Very strong? Partially strong? OK, so that's my intuition for this. OK, so now let me tell you a theorem. OK, I can't help myself. I'm a mathematician by training. If I don't have at least a single theorem, someone's going to roll over somewhere. And I don't want to specify further.
OK, so before I try to talk about feature learning and curse of dimensionality and anything like that, let me just state one cool theorem, which is a very, very, at least in retrospect, humble situation, in which we know that this is OK. OK, so if you've ever heard of the NTK or the lazy regime or the kernel regime, OK, or the linear regime, OK, anything that has a lot of names means it's been studied by a lot of people. That's my heuristic.
OK, have people heard about this? Have people heard these words, NTK, kernel? OK, whatever, I'm going to state a theorem. And it will basically be a theorem which says, there exists a regime in which these neural networks are extremely simple. And it's really easy to check that this Jacobian has full rank. OK, and therefore, optimization is going to be successful, and so on.
OK, so let me figure out exactly how I wanted to say it. Because I thought of an efficient way of stating it. OK, yeah, so it goes like this. So let's fix the training data. OK, you have a fixed, let's say, 10,000 training data points. But that's not going to change.
So, in particular, I'm going to fix the input dimension, and I'm going to fix the output dimension, OK, and I'm going to fix the depth of the network. They're all going to be fixed. It's a theorem, so you have to fix some constants before you do anything.
OK, so then, if I'm going to talk about training, I'm going to initialize in a particular way. This is like PyTorch default, if you want. So I'm going to initialize my weights like this. They're going to be independent Gaussians with mean 0 and some variance that scales a constant over the width of the previous hidden layer. OK, that's the way PyTorch typically initializes things. Very good.
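For concreteness, a minimal PyTorch sketch of that initialization; the constant C_W and the widths here are placeholder choices of mine, since the talk just says some constant over the width of the previous hidden layer.

```python
import torch.nn as nn

C_W = 2.0                                  # placeholder constant
widths = [10, 256, 256, 1]                 # N_0, N_1, N_2, output

layers = [nn.Linear(n_in, n_out, bias=False)
          for n_in, n_out in zip(widths[:-1], widths[1:])]
for layer in layers:
    fan_in = layer.weight.shape[1]         # N_{l-1}, the width of the previous layer
    nn.init.normal_(layer.weight, mean=0.0, std=(C_W / fan_in) ** 0.5)
```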
OK, so then, I'm going to-- OK, as the hidden layer widths N1 up to Nl go to infinity, something kind of magical happens. OK, so before I tell you what happens, let me just say, this is a very common thing to do in physics. If a neural network is very big, maybe you might wonder what an infinitely big neural network does. And this is one particular way to understand infinitely big neural networks.
You fix the training data set. You fix the input, output dimension, the depth. Everything, you just ask, what happens when I scale the hidden layer widths and make them bigger and bigger and bigger? OK, so neural networks become shockingly, sadly, overwhelmingly simple.
OK, so two things happen. So at initialization, so at the start of training, when you just select your weights i.i.d. at random, the neural network itself as a function, so of x, x is the input. This converges to what's called a Gaussian process. OK, so if you don't know what a Gaussian process is, that's OK. Just think of it as a Gaussian.
And the key point maybe-- OK, I probably go too far, and keep this away from the neuroscientists-- is that it's a Gaussian process with independent neurons. OK. So there are no correlations among neurons anywhere in your network. If you take any finite number of neurons, let's say the outputs of the network, they're just completely independent at the start of training. There's no propensity to actually compute any interesting, correlated signal.
OK, but not only that-- OK, so during optimization, so let me say, by gradient flow-- or just think of gradient descent with a sufficiently small learning rate. So just literally the thing you think, this kind of dynamics. OK, you take the loss, script L. Always, for me, for now, the loss is the mean squared error. So I have some fixed data set.
I'm just minimizing the mean squared error by randomly initializing my vector of parameters and doing gradient flow. OK, so what happens when you do this? Well, it's really weird. So during optimization by gradient flow, you can just replace the beautiful nonlinear neural network that you were hoping, like Tomer and Akshay before us, to do some transfer learning with or learn something interesting with, you can replace the whole network function by its linearization.
OK, so I'll call it z lin of theta. So let me write it down, and then I'll explain what I mean. So this is the network frozen at initialization, where you only keep the first order term in the Taylor expansion, OK, theta minus theta 0. Oh, wow, that's just enough room.
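Written out, the linearization on the board is the first-order Taylor expansion of the network in its parameters around the initialization theta_0:

```latex
z^{\mathrm{lin}}(\theta; x) \;=\; z(\theta_0; x) \;+\; \nabla_\theta z(\theta_0; x)\cdot(\theta - \theta_0).
```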
OK. So let me just, again, say this in words. This means, that instead of trying to understand what training looks like with your neural network, which is complicated and nonlinear, and so on, you can just completely punt. You say, well, neural networks, schneural networks. I'm just going to consider the following linear model.
OK, it's just take your neural network, and you consider the first order Taylor expansion of your neural network around initialization. And you just train this thing. This thing is a linear model. This is the only place the theta appears, right? Everything else is theta at the start of training unless I made a mistake.
So it's really crazy. There's this regime where you make the network wider, wider, wider, wider, wider, at least in the way I'm specifying. And the network gets closer and closer and closer and closer to just being a kernel method, or a linear model, OK, linear as a function of its parameters. That's what it means to be a linear model.
OK. So let me just pause for one second. There's a lot of theory that came out of this. And I'll critique this in a second and tell you what comes after. But is the statement of what I'm saying kind of clear? It might not be clear why it's true, but at least the first time I saw it, and many people, it was somewhat shocking.
So all the assumptions are here for a reason. OK, if you break any of these assumptions, it's not going to be true. But I'm just saying, it's crazy. We know linear models. We've known them for 100 years, since before any of us were born.
And somehow, neural networks have hidden inside them, at least in a certain scaling limit, linear models. That's really cool. So just to check some understanding, what happens to this Jacobian? So if instead of the neural network, you just allow me to use this linearization, what's the Jacobian that I get?
AUDIENCE: It's constant.
BORIS HANIN: Yeah, it's constant. OK, great. Which constant is it?
AUDIENCE: [INAUDIBLE]
BORIS HANIN: OK, sorry, so indeed. So OK, I don't know your name, but I can't fool this person. They say, well, OK, it's really easy to compute the gradient of a linear function. Theta appears only here, so I just get this thing. That's my Jacobian right there. It's just a constant.
OK, so in other words, what this is saying is, if your network is wide enough, all you have to check to make this picture precise is just that at the start of training, your Jacobian has full rank. So because it's frozen at initialization, I mean, it's constant as a function of time--
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yeah, yeah, sorry, that's the key point. You linearize only around the start of training, and then you keep that linear model forever, right. OK, so that's what this whole line of work is about. So if you're ever interested in it, OK, now you know. And it's all about, under what conditions? How wide do you have to be, blah, blah, blah, blah, blah? But I won't go into it too much.
OK, so that's my story about a little bit of optimization and one situation in which it's easy to check that this heuristic I have is true. But I think, as many people here already, basically, know-- well, these erasers don't work. And this is really sad. So maybe let me ask you, especially if you've never seen it, why should this theorem make you sad rather than happy?
Or maybe it makes you happy. Maybe you're a contrarian. I don't know.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yeah, it's pretty sad. Didn't we just hear talks about transfer learning and learning features and all of this? It's exactly the kind of thing linear models can't do. Ooh. OK, so I'm saying, from my point of view conceptually, all these theorems are really useful in the sense that they suggest one concrete mechanism by which optimization might be easy in some practical-ish regime. But they're very sad as a model to study because you can't actually study feature learning. You have to start there.
OK. So which brings us to feature learning, question two, very good. OK, so to escape the NTK regime, there are many things that you can do. OK, basically, the rule of thumb is, anything you do will escape the NTK regime. OK, so let me give you a list of options, and then you can tell me what you like.
OK, so you can change the learning algorithm. Change the learning: you can change the loss. Or you can keep the step size big. Or you can add regularization. Or you can do almost anything to change the learning, and you won't any more be exactly a linear model. OK, that's already kind of cool, but harder to study.
So what else can you do? Well, you can consider the regime where the number of data points scales with the width of the network, scales with the width Nl of your various hidden layers. So this is not unreasonable. We do in practice, in fact, do this nonparametric thing.
The bigger the data set, the bigger the network you use to train. And this seems to really help. And indeed, doing this, you get totally, totally different answers and dynamics. And we understand very little of what happens.
OK, so the other talk I was considering giving you is I just proved a result about this. OK, but it was niche if you don't have this intro in some sense. So I decided not to go for it. So the short answer is, yes, there is a simple intuition.
And rather than spoiling it for you-- so consider the case of just one hidden layer, and make the following observation. What does it mean to be a linear model? OK, it's a little embarrassing to ask, what's the definition of a linear model? But nonetheless, well, OK, there are several ways to define it.
But maybe the most convenient way is, by definition, a linear model is one in which the Hessian equals 0. Right? Wouldn't you agree that if the Hessian equals 0 with respect to parameters, you're a linear model. OK. So the key point and one cool conceptual way to prove these theorems is to show, that as the network gets wider and wider, the Hessian actually becomes smaller and smaller. Whereas, the Jacobian, the actual gradients you use for training remain order one.
AUDIENCE: Why does the Hessian get smaller?
BORIS HANIN: So, essentially, the Hessian gets smaller because it's the easiest to see maybe in a one-layer network. The Hessian tries to mix the derivatives with respect to all the parameters. But the only ones that can survive are the parameters that correspond to a single neuron. And so the Hessian actually has this block diagonal structure. It's mainly zeros. It's just not a dense matrix anymore.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: So all I'm saying is do the calculation. You can do it. Write down a one-layer network, and you'll see a beautiful block diagonal structure on the Hessian. And you'll see that the Hessian becomes small, whereas the Jacobian stays order one.
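A minimal PyTorch sketch of that calculation for a one-hidden-layer network. The explicit 1 over square root of N scaling on the output and the tanh nonlinearity are my own choices here, made so that the comparison he describes is clean: as the width grows, the gradient norm stays order one while the Hessian's operator norm shrinks.

```python
import torch

def one_layer_net(theta, x, n):
    # f(x) = (1/sqrt(n)) * v^T tanh(W x), no biases; W is n x d, v has length n
    d = x.numel()
    W = theta[: n * d].reshape(n, d)
    v = theta[n * d:]
    return (v @ torch.tanh(W @ x)) / n ** 0.5

torch.manual_seed(0)
d = 5
x = torch.randn(d)
x = x / x.norm()

for n in [50, 200, 400]:
    theta = torch.randn(n * d + n)          # O(1) entries; the scaling lives in the output
    f = lambda t: one_layer_net(t, x, n)
    grad = torch.autograd.functional.jacobian(f, theta)
    hess = torch.autograd.functional.hessian(f, theta)
    print(f"n={n:4d}   |grad| = {grad.norm().item():.3f}   "
          f"|Hessian|_op = {torch.linalg.matrix_norm(hess, ord=2).item():.4f}")
```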
AUDIENCE: I just want [INAUDIBLE].
BORIS HANIN: Well, so that's how you get a feel.
AUDIENCE: No, [INAUDIBLE].
[LAUGHTER]
BORIS HANIN: No, you do the calculation, and then you get the feeling. That's how. So I mean, I'm happy to do that calculation with you after. It only takes five minutes. So but yeah, it's not totally obvious. It's not completely clear at first why this should be true. I'm just saying this is a cool result.
OK, so you can take this. You can also consider the case. Consider the depth, L growing with N. So one of the questions I've been obsessed with for many years is, what's the role of depth in deep learning? It's a little sad. Linear models are depth zero networks.
If you have zero hidden layers, you're a linear model. And but wait, I thought this was a depth L model. So how did it become a linear model? That's related to your question. And there's a bunch of stuff that I've been working on this direction.
And indeed, if you think about networks that are both deep and wide, actually you're not close to linear. OK, did I have anything else? Let me see. Oh, yeah, let me say one other thing.
So you consider a different initialization, different init. OK, so it turns out, if you've ever heard of muP or mean field initialization, or anything, instead of taking 1 over N in the variance, if you change it, and you make smaller weights in the final layer, you get a totally different set of dynamics that happens. OK, again, it's not completely obvious why, but I will try to draw a picture for you if I have time.
I have no idea when I started. And so this might go on forever, as far as this is concerned. So someone's going to have to tell me. Yeah?
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yes, I'm saying the entire trajectory of optimization will become that of the linear model. Well, I mean, the thing is, as you send the width to infinity, you don't have a fixed network. So really maybe the thing to say is, if you just consider the optimization trajectory of your neural network at sufficiently large width, then it will be arbitrarily close to the optimization trajectory of the linearization. Maybe that's a precise version of the statement.
And then that difference goes to 0 as you go wider and wider, yeah. OK, very good. So OK, let me tell you something about how to think about the role of depth because that's one of my pet things. And I can draw a simple picture about it and even show you a simple calculation.
OK, if you've ever heard me give a talk before, I will reuse one of my favorite jokes, which is that every successful talk, at least for a mathematician, has to have a theorem, a picture, and a proof. OK, but the key thing-- oh, sorry, a theorem, a picture, a proof, and a joke. OK, that's the key.
And the key thing is that the joke and the theorem can't be the same thing. OK, so here's the theorem. And that was my joke. And OK, I'm about to do the calculation, and it'll be good. OK.
So let me double-click, so to speak, on this thing. And I want to convince you of yet another piece of intuition. OK, so how should you think about the role of the depth and the width and the network, at least without skip connections? You can generalize it to skip connections, but the answers are somewhat different.
OK, so for this intuition, I want to think like a person who does dynamical systems. OK. So I'm going to draw L down here, and I'm going to think of my hidden layer widths as just being generically denoted by N. OK? So I think a productive way of thinking of a neural network is to think of it like a dynamical system, where your time 1 evolution is go through one layer of the network.
OK, every time you apply a transformation, something interesting happens. So if you are serious about this analogy, and you can tell I'm serious by how excited I am, L is like a time parameter. That's how you should think about it, not unreasonable.
OK, but if you then think about it, what's the role of the width? Well, it's just literally the size of the state space, number of degrees of freedom. So I'm going to say degrees of freedom, not dean of faculty, OK, DOF. OK, so already, I want to try to convince you that L and N fight each other. OK, they're on opposite sides of the war. I don't know what they're fighting for, but then, again, most wars are senseless. OK.
So why should that be? Right? I mean, if you really take this analogy of dynamical systems, can you explain to me why N and L play different roles or pull in different directions? OK, great. So indeed, the point is, that maybe if I can rephrase it slightly, bigger systems, ones with more degrees of freedom, take longer to come to equilibrium.
OK, so the bigger I make N, the larger I should make L in order to watch them correlate the inputs to the network or correlate the dynamics. OK, and indeed, so here's the meta theorem. I'm going to put it over here. I'm just saying, consider this as like an analogy for now. And I'm about to write a precise version of this analogy.
But I'm just telling you, that when you process an input through many, many transformations, what different inputs share is they share the same weights. So you should expect them to become like somewhat more correlated as they go through the network. That's a generic behavior.
But the bigger the size of the input, the longer it takes for the correlations induced by the network to beat the differences that came from the input. It's just like a heuristic. So here's my meta theorem, and then I'm going to do a calculation for you. But this is going to answer the question of, just how wide do you have to be in the NTK regime that I just erased?
OK, so here's my meta theorem. And then I'll do the calculation on the left. OK, so when the Nls are large but finite-- OK, so when you have a pretty wide network, but you just haven't sent them to infinity, then here's how you measure the effect of depth. OK, the effective depth of the model, how far you are, if you want, from being a linear model is actually the sum of the reciprocals of the N's, which you should just think of as the ratio of the depth to the width, if you think of all the widths as being roughly equal.
OK, so I know it's a little vague, but let me put it this way, e.g., so the distance between your neural network and its linearization turns out to go like L over N. So at a fixed depth, I told you that you get closer and closer and closer to being a linear model. That was the NTK statement.
But what I'm saying is that depth actually amplifies the non-linearity that you have in your network. And you get farther and farther from being linear when you go deeper and deeper. OK, so that's a basic thing.
The correlations between neurons, all right, go this way: if you just look at how much different neurons wire together, even just at the start of training, how much propensity do they have to be correlated? OK. This also scales like L over N.
So many, many effective properties that try to say, how non-linear are you, go to 0 when N goes to infinity, but tend to grow when L grows. And they grow in this ratio. OK, and my purpose here is I want to show you the simplest calculation possible to give you a little bit of intuition for why it's this ratio, as opposed to something else.
So things become more correlated as you go deeper and less correlated as you become wider. Yeah, sorry, so I mean, I'm being a little vague just so that I don't have to do all the notation. But just think like this. Think of just fix an input to the network, and consider the correlation between some two neurons in the network, let's say two neurons in the same layer.
I'm saying, as you go deeper, their pre-activations will become more correlated. And as you go wider, they'll become less correlated. OK, let me be a bit more precise. So let me look at the correlation between, let's say, the first neuron in the output at x and the second one. OK, technically, you have to put a square here because the first correlation vanishes, just due to symmetry of the weights.
OK, so the correlation between how strongly the first neuron fires in the output, and the second neuron fires in the output at the same input. And the randomness just comes from the initialization. So I'm saying, this scales like L over N. That's precisely what I mean. Yeah?
No, we're about to calculate the identity case. Even that case is interesting, although it doesn't seem interesting at first. But OK, Surya told us a lot of interesting things by studying deep linear networks. And I'm going to try to tell you something interesting about studying deep linear networks. I can do the calculation quickly. Yeah?
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Correct. So you're totally right. And I'm brushing under the rug something that Surya alluded to, which is a question of, OK, if you want a dynamical system not to do something crazy, you have to tune the parameters of the dynamical system to be close to criticality. And then once you do that, and you study things at large depth, then this L over N emerges.
Otherwise, things are exponential in the depth. You either contract or blow up, or something crazy happens. So I'm asking you to trust me a little bit so that I don't have to write all that out. But yes, you're correct.
OK. So let me do the calculation. My advisor would kill me if I didn't do this calculation. OK, so here's going to be my intuition. I'm going to just take sigma to be the identity, as promised. OK, I'm going to take x to be an input of unit norm. And I'm going to take my weight variance to be 1.
OK, so I'm just going to study the simplest possible thing. And you'll allow me, hopefully, to go from L to L minus 1, or from L plus 1 to L. I'm just going to study, what is the distribution of the output of my neural network at a fixed input? How does this actually behave as a function of L and N?
OK, this is the simplest kind of thing you could do. Imagine you're just at initialization. So all of these weights are i.i.d. Gaussian. OK, the W_l,ij's, they're all Gaussian, mean 0, variance 1 over N. All my hidden layer widths are the same size. I should write that. N_l is equal to N. It's the simplest possible thing, OK, just so I don't have too much notation.
So this is a deep linear network at the start of training. And I'm just asking, what does this look like? That's a moderately down to earth question. But already, it's interesting and not totally trivial, as you're about to see.
OK, is it clear what I'm trying to do? I'm trying to assess the effects of depth and width by studying what's arguably the simplest possible statistic about a neural network that I can imagine. You just have no non-linearity random weights. And I'm just asking, how big is the input when I've normalized-- sorry, the output when I've normalized the input to have size 1? That's literally what I'm studying.
OK, so let's do it. I need four lines of calculation, I think. So here comes my calculation. OK, what should I do? Any suggestions from the peanut gallery, so to speak? Probably Surya knows all the tricks, so I won't ask him.
But if you had to understand the distribution of this random variable, fixed input, product of Gaussian random matrices norm, what do you do? Not quite. Not quite. OK, so let me remind you, if you have a Gaussian matrix times a unit vector, this product is just a standard Gaussian vector.
So let me do this. Let me do the only thing I can think to do. Let me multiply and divide by how much the first layer scales my network. OK, that's a legitimate operation. Right?
OK, so think back to your probability classes. OK, maybe you never took a probability class. But just say you did. So OK, this is the direction in which a standard Gaussian vector points. OK, you take a Gaussian vector and N dimensional space, means 0 identity covariance, and this is the unit vector in its direction.
So what's the distribution of this thing? OK, I got the uniform on the sphere, right? The Gaussian doesn't have a preferred direction. Very nice. Moreover, the direction is independent of the length. OK, because the density of the Gaussian is purely radial. It factors as the uniform measure on the sphere times the--
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yeah, yeah, I'm saying, the distribution of this whole unit vector is independent of the size of the Gaussian. That's because the Gaussian has a density that's radial. OK, so amazing, I claim I'm done. It's just you don't see it yet maybe.
OK, so this is the norm of a Gaussian, so also known as a chi squared random variable with N degrees of freedom. And the normalization is 1 over N due to my variance, OK, with a square root, because I don't take the norm squared. This is just the square root of the sum of the squares of independent Gaussians. So that's a simple, well-known distribution.
And now the rest of this is the same problem I had, but with one fewer weight matrix. This is just some unit vector, and I keep going. OK, so the point is like this. You keep doing this, and you say, in distribution, this is just the product of independent random variables. l goes from 1 up to capital L. And you take a bunch of chi squared random variables that are appropriately scaled, and you raise them to the 1/2 power.
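Written out, with the conventions fixed above (sigma the identity, unit-norm input, all widths equal to N, weight variance 1 over N, and, as a simplifying assumption of mine, the input dimension also taken to be N), the identity being built up is:

```latex
\big\| W_L \cdots W_1 x \big\|
\;\overset{d}{=}\;
\prod_{\ell=1}^{L} \Big(\tfrac{1}{N}\,\chi^2_{N,\ell}\Big)^{1/2},
\qquad \chi^2_{N,\ell}\ \text{independent chi-squared, each with } N \text{ degrees of freedom.}
```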
OK, so we exploited very heavily the fact that everything is Gaussian to shortcut to the answer. But already, you see something amazing. And I'm about to say something slightly more about it.
So each of these chi squared random variables tries to be closer and closer to 1 as N grows. This is the sum of squares of appropriately normalized Gaussians. And so if you know anything about Gaussians in high dimensions, their norm tries to concentrate on the unit sphere, if you scale them the way that I scaled them. So this thing tries to get closer to 1.
But now you're taking, as L grows, a product of more and more of them. And that tends to exacerbate the fluctuations. OK, so let me write one more formula down because I can't help myself. I'm just going to exponentiate this formula and write it as e to a sum of logs. OK, I'm going to put the 1/2, and write the log of 1 over N times chi squared with N degrees of freedom.
OK, so what do you learn? You learn that the distribution of this very naive-seeming random variable is quite interesting. It's e to the sum of IID random variables. OK, that's exactly what it is. Each of these chi squareds is independent of all the others. So you can apply the central limit theorem.
OK, at large L, that's the number of terms in the sum, you just have to compute the mean of each of these and the variance of each of these and add them up. And just trust me when I say, you get an e to a Gaussian with mean minus L over 4N and variance L over 4N. That's exactly what you get when you apply this CLT.
OK, so the main thing I want to say, even if you didn't follow all the details of the calculation, is that the answer depends only on the ratio of L over N. This is the simplest way I know how to see L over N appear. And moreover, what happens is, each layer has 1 over N correction.
But then you add up L of them, and if L and N scale together, that's how the correction scales. OK, L over N is the thing that controls how close you are to the start of training or to being 0 depth. OK, so instead of mathematically taking N to infinity at fixed L or L to infinity at fixed N, you should really think about taking them to infinity together at a fixed ratio. So L over N plays the role of an inverse temperature. Here, beta is equal to 0. Here, beta is equal to infinity. That's the right way of thinking about what happens.
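A minimal NumPy check of this, under the same setup (deep linear network, i.i.d. Gaussian weights with variance 1 over N, unit-norm input). It just estimates the mean and variance of the log of the output norm over random draws and shows that both track the ratio L over N, without committing to the exact constants:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_output_norm_stats(L, N, trials=500):
    # log || W_L ... W_1 x ||, W entries ~ N(0, 1/N), ||x|| = 1
    logs = np.empty(trials)
    for t in range(trials):
        z = np.zeros(N)
        z[0] = 1.0                                      # unit-norm input
        for _ in range(L):
            z = rng.normal(0.0, 1.0 / np.sqrt(N), size=(N, N)) @ z
        logs[t] = np.log(np.linalg.norm(z))
    return logs.mean(), logs.var()

for L, N in [(4, 16), (8, 32), (16, 64), (8, 16)]:      # first three share L/N = 0.25
    m, v = log_output_norm_stats(L, N)
    print(f"L={L:2d}  N={N:3d}  L/N={L/N:.2f}   mean={m:+.3f}   var={v:.3f}")
```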
OK, so things become subtle when you go deeper. People ask, what's the effect of depth? There's a lot of possible answers to it. There are approximation theory answers. There are optimization answers.
But there's also answers which basically say, the deeper you go, the more nonlinear you are, the more you have the ability to learn features. OK, that's some very pale shadow of the fact that deeper networks, at least for moderate depth, seem to learn better features. And this kind of explains a little bit what's going on.
OK, so let me stop rambling about this for a second. OK, let me just pause for a second and let people digest this. And I'm going to write one more thing in this row here, which I think is cool. OK, so if you also study, just to connect to what we had before, the variance of the gradients-- OK, so these are the gradients you use to actually do optimization.
OK, so what happens is the variance of the gradients also grows like L over N. And you get this simple-minded story, at least for fully connected networks. So the theorems I'm alluding to work for non-linearities as well.
So as you make L over N bigger, you learn features. You're farther from being a linear model. But as you make L over N smaller, you're more stable. You can actually optimize when you do gradient descent.
And that's what I'm saying here. If the variance of the gradients is too big, you get what are called exploding and vanishing gradients. It's very sad. Optimization will not work, even though your model was beautiful.
But if L over N is too small, you're too close to the NTK regime, then OK, you're going to train. You'll be stable, but you won't learn features. And you've, again, been fired from your job, which you've worked so hard to obtain. OK, so you've got to keep moderate L over N. That's my point, unless you use transformers, in which case, I'm not 100% sure what the analog is. OK, it's an interesting question.
So remember I started with this intuition that L and N fight each other. And it was very vague. And so I wanted to show you the simplest calculation which shows how they fight each other.
And so this is not a calculation which is meant to be useful for making grand conclusions. But it's the most direct, actual complete calculation which shows, that as L gets bigger, things have bigger variance. They fluctuate more. The variance of your Gaussian grows, grows, grows, grows, grows.
But when N gets bigger, the variance and mean of the Gaussian get closer, closer, and closer to 0. Things become more tame. And the mechanism is each layer tries to push you away from being Gaussian, away from being very nice. And then as you go wider, you get closer to being nice. That's what this calculation was meant to illustrate. Yeah?
So I'm saying, at initialization, I'm saying you haven't done any training yet. How stable is your network? Are you going to be able to run gradient descent or not? And I'm telling you, if L over N is too big, you're not going to be able to run gradient descent. It's just going to be wild, and you won't really learn anything. Or it'll take you a long time to start learning.
But if L over N is too small, OK the, theorem I erased says you will be able to run gradient descent, but you'll just be learning a linear model. So it's like some effective description of how non-linear you are or how chaotic you are, maybe, yeah. OK, very good.
So let me stop talking about this, and let me talk about one more thing, which I think is pretty cool. Oh my god, I have three more things prepared? OK, well, I'll just talk about one of them and see how that goes.
So let me, in the spirit of being like Surya, which should be the goal of all of us, I feel, OK, take a vote as I try to erase. So there are three different options of what to say next. So I'm trying to give you intuitions, right? I told you about why optimization should be easy when you're overparameterized, something about how to measure the depth of a network. It's a little more subtle than you thought.
And there are three other things I can talk about. I could talk about how neural networks might be connected to optimal transport. If you've ever heard of optimal transport, that's pretty cool. I can talk about something about algorithmic regularization. We've already heard about implicit bias of SGD. But I'd like to draw a picture for you about the implicit bias of optimization algorithms. That's something I dream of doing.
Or I could tell you something about why double descent happens. Oh. OK, who wants to hear about implicit bias of algorithms? OK, we've got a couple of votes. Who wants to hear about optimal transport? OK. That seems to be more exciting. And what was the last one I said? I already forgot. Oh, yeah, who wants to hear about double descent? Oh. OK, well, I didn't count.
Fine. I'm going to say something about double descent. OK, I'm just going to brutally break the tie this way. And I'm happy to talk, as you can tell, offline about anything for any amount of time.
OK, fine, so what's a thing that people thought was really incredible and had to be something deep about neural networks and turned out not to be at all? OK, as a spoiler-- so you make the following kind of plot. On the x-axis, you plot some notion of complexity of your model, something like the effective number of parameters of your model. OK, think about a network that's getting wider and wider or deeper and wider, or whatever.
OK, and on the y-axis, you plot the test error. And then here, you fix the number of data points in your model that you train on. So here's number of data. OK. So what's the picture people get when they do this in practice, whether it's neural networks or other things?
Well, when you're underparameterized, when the number of parameters is less than the number of data points, you get the usual bias variance trade-off or something like it, which tells you that, if you fit the data too well, you're going to overfit. If you don't have enough parameters, you're not going to fit the signal, Goldilocks principle, whatever.
OK, but then, people didn't really think, at least not very conceptually or in great generality, about what happens when things are overparameterized. But what you find, in fact, is that in many cases, things are OK. When you're overparameterized, as we've already been through, there are infinitely many ways to fit the data.
But somehow, the algorithms that we use in practice, these greedy, local searches that we do, actually find nice ways of fitting the data. And life is actually pretty good. And so there is this big mystery about-- so, OK, they call it double descent. It's not very surprising or creative. OK, those are the two descents. That's all they mean.
OK, the truth is things can be more complicated. In some models, you can actually see and prove that there are other intermediate behaviors. But let me leave that aside, OK, multiple descent, whatever.
OK, so the thing I want to maybe say is, why is this picture reasonable? We first discovered it, quote, unquote, for neural networks. But it's true for decision trees. It's true for linear models. It's true for almost everything you've ever heard of.
And the question is, why? What I really want to focus on, because I think people don't talk about it that much, is why do you expect these spikes? What causes the spike to happen when you have bad test error? Why does it happen right at the interpolation threshold when you have just enough parameters to fit the data? What's going on?
So maybe as I erase, someone can tell me what their best guess is. Does anyone have a good mental model for why you have a jump in the test error right at the interpolation threshold? So why does this happen?
OK, let's just think of a simple linear model, like z of theta is whatever, theta transpose x. And you generate a bunch of data. y_i is theta star transpose x_i plus epsilon_i, some independent noise, simplest possible situation. I claim, when you draw this kind of picture, at least if you don't regularize, if you just use the mean squared error, L of theta is a sum, i goes from 1 up to the number of data points-- OK, you know what it is better than I do.
But I can't help myself: theta transpose x_i minus y_i squared. OK, when you actually just either take the minimal norm solution-- or let's actually think about the blow-up from the left. So there's a unique minimizer of this loss. So I'll try to show you why, as the number of parameters grows to be the number of data points, you expect a spike in your test error. OK, why should this be?
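Written out, the toy setup on the board is:

```latex
y_i = \theta_*^{\top} x_i + \varepsilon_i, \qquad
\mathcal{L}(\theta) = \sum_{i=1}^{\#\text{data}} \big(\theta^{\top} x_i - y_i\big)^2 .
```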
And I don't think there's a theorem about it. But my sense is that this mechanism is totally generic. And it's the reason for spikes and all the test losses that people see. It's just harder to prove or really get your hands on maybe.
Let me put it this way. So what you're saying is totally reasonable. But maybe let me better ask you, in how big of a window around the number of data points should I expect the test loss to be blowing up? Maybe that's the more honest question to be asking, the one I really have in mind.
I agree with you, that when things are super rigid, you have one parameter per data point, you somehow overfit for the noise. But OK, what if it were a neural network? Then it's harder to think about, right? And I claim there's a very generic mechanism, like that picture with the overparameterized optimization that I was writing down, for why this is going to happen.
So OK, so let's start with an easier question. This is like the kind of question you have to do in order to get a job in like 2013 working for Facebook AI. OK, what's the minimizer of the loss? What's the formula for the minimizer? I swear they asked me when they were interviewing me at some point for a visiting position. And I was like, oh god. That was my response.
So let me see, I have to take theta hat equals-- OK, I messed up-- OK, it doesn't matter. So x transpose x, inverse, let's check, x, y-- that's your claim? I think that's almost right. It's almost right. Let me do it-- I don't think that's quite right. I think I need to put the x transpose over here maybe. That look OK?
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Right. Oh, no, the x. It's an x. OK, now it's OK. OK, I think I got it right. [INAUDIBLE] yes. If you take theta transpose times x, then the x transpose x cancels, and I get y. OK, I think I win. OK, very good.
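Cleaning up the board work: writing X for the matrix whose rows are the training inputs, the minimizer in the underparameterized case is the usual least-squares estimator, and the minimal norm solution he mentioned for the overparameterized case is the second expression below.

```latex
\hat\theta = \big(X^{\top} X\big)^{-1} X^{\top} y
\quad\text{(underparameterized)},
\qquad
\hat\theta = X^{\top}\big(X X^{\top}\big)^{-1} y
\quad\text{(minimal norm, overparameterized)}.
```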
So there's a formula for it. That's the beauty of the linear model. If anyone tells you that you can't understand things, you can say, well, I can understand anything for a linear model. There's a formula. But you don't really have access to the true theta star. You have to estimate it from the data.
And really, this is x on the training data, right? This is what you use to construct theta star. So let me not call it theta star. Let me call it theta hat. Because sometimes I talk to statisticians, and they put hats on everything. OK, fine.
So what can go wrong, when you are trying to say, is this close to theta star? So theta star, for what it's worth, is like-- oh, god, it's like the expectation of x transpose x, inverse, times the expectation of x transpose y. I think that's probably true, not 100% sure.
But the point is you're just trying to estimate the mean of the empirical covariance matrix, the true covariance, by using just a plug-in estimator, just the empirical covariance. Yeah? So here's the point. This inverse-- so I'm going to channel my inner you. I don't know your name, but you're completely right.
So the trouble is, if you make a mistake in this inverse, things can go really wild. OK, just think about it like this. So let's draw the spectrum of this matrix x transpose x with the expectation. OK, so here's the spectrum of this matrix. This is the spectrum of the expectation of x transpose x.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Not necessarily. I'm just drawing for you. You have to make some assumption on your data. And I'm just saying x has some covariance. And I'm just drawing for you the eigenvalues of this covariance.
And the key point is I don't need to just estimate the expectation of x transpose x. I need to estimate the expectation of-- and then take the inverse of it. I really have to estimate the inverse. So the thing that's really going to screw me up is if I can't estimate the smallest eigenvalue. Because the smallest eigenvalue of a matrix becomes the biggest eigenvalue when you take the inverse.
So the point is, I don't actually get to observe the expectation of x transpose x. That's not how it works. I observe a noisy version of it, which is just x transpose x itself on my data. So you see there's, basically, two scales at play. There's lambda min, the smallest eigenvalue of the expectation of x transpose x.
AUDIENCE: [INAUDIBLE]
BORIS HANIN: Yeah. OK, so there's this scale. There's lambda min of the expectation of x transpose x. So I have to estimate this quantity. But there's also a noise that gets induced by observing some random data and trying to find the minimal eigenvalue. So there's none of these colors.
So what I really observe is lambda min of x transpose x. And I'd like to know, when are these two close? So really the question that you have to answer is, can I accurately estimate the smallest eigenvalue without being able to distinguish it from being 0? So the point that was just made is that, when the smallest eigenvalue is equal to 0, lambda min of the expectation of x transpose x is exactly when the number of data points is equal to the number of parameters. That's what the Marchenko-Pastur law tells you.
So exactly when you have a square matrix, the smallest eigenvalue is going to touch 0. And so it becomes harder and harder to accurately estimate what the smallest eigenvalue actually is. I only observe a noisy version of it.
OK, so basically, the heuristic is the following. It's very simple in the end. If the square root of the variance, so the standard deviation of lambda min, of x transpose x is much smaller than lambda min of the expectation of x transpose x, then you're OK. You get a good estimate for that inverse. OK, it's just I have to be able to distinguish lambda min from being 0. That's the only thing I have to do.
And OK, when you plug through the Marchenko-Pastur law, or whatever you want, you'll get a scaling limit, which tells you exactly how close the number of data and number of parameters have to be. That's really the most generic mechanism I know for a blow-up in the test error. Essentially, the variance blows up.
But it's not just the variance. It's that you don't have a good estimate for the inverse of your kernel. It's really hard to estimate to resolve the small eigenvalues. If you can't resolve the small eigenvalues, you may as well set them equal to 0, and you get infinite kind of predictions, really crazy, wild things.
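A minimal NumPy sketch of this mechanism in the toy linear-regression setup above; the sample size, noise level, and feature counts are arbitrary choices of mine. The minimum-norm least-squares fit shows the test error spiking right around the point where the number of parameters equals the number of data points:

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_test, noise = 40, 2000, 0.5

def avg_test_mse(p, reps=30):
    errs = []
    for _ in range(reps):
        theta_star = rng.standard_normal(p) / np.sqrt(p)        # fixed-size true signal
        X = rng.standard_normal((n_train, p))
        y = X @ theta_star + noise * rng.standard_normal(n_train)
        theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)       # minimum-norm least squares
        X_test = rng.standard_normal((n_test, p))
        y_test = X_test @ theta_star + noise * rng.standard_normal(n_test)
        errs.append(np.mean((X_test @ theta_hat - y_test) ** 2))
    return np.mean(errs)

for p in [5, 20, 35, 38, 40, 42, 45, 60, 120, 400]:
    print(f"#params = {p:4d}   test MSE ~ {avg_test_mse(p):10.2f}")
# the error blows up near #params = n_train = 40, where the smallest
# eigenvalue of X^T X cannot be distinguished from 0
```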
OK, so I know I was slightly rushed. But I didn't want to keep us over time. I'm happy to talk about this offline, too. So let me just stop. I have lots of other things prepared, as I said. But I've tried to tell you what I think are some cool things you can take away to your everyday lives, so to speak, without having to prove theorems about neural networks. So OK, that's it. Yeah.
[APPLAUSE]