Linear Neural Networks and Regularization
Introduction
Regression problems are ubiquitous in machine learning, serving as the foundation for predicting numerical values based on input features. While modern deep learning architectures can be incredibly complex, they often share the same underlying components as simple linear models: a parametric form, a differentiable objective function, and an optimization algorithm. This document focuses on the linear neural network, a shallow architecture where inputs connect directly to outputs. We will define the model formally, examine its implementation, and discuss the inevitable challenges of generalization, including how to detect and resolve underfitting and overfitting using techniques like weight decay.
1. Definition of Linear Neural Networks
Linear regression is the simplest tool for regression tasks, where the goal is to predict a continuous numerical value. In the context of neural networks, linear regression can be viewed as a single-layer network where every input feature is connected directly to the output via a weighted sum. We assume the relationship between input features \(\mathbf{x}\) and the target \(y\) is approximately linear, subject to additive noise.
Given a dataset of \(n\) examples where each example consists of \(d\) features, let \(\mathbf{x}^{(i)} \in \mathbb{R}^d\) denote the features of the \(i\)-th example and \(y^{(i)} \in \mathbb{R}\) denote its label. The linear regression model predicts the target \(\hat{y}^{(i)}\) as a weighted sum of the inputs plus a bias term: \[ \hat{y}^{(i)} = \mathbf{w}^\top \mathbf{x}^{(i)} + b \] where \(\mathbf{w} \in \mathbb{R}^d\) is the weight vector and \(b \in \mathbb{R}\) is the bias (or offset).
The bias \(b\) allows the model to represent linear functions that do not pass through the origin. Strictly speaking, adding a bias makes the map affine rather than linear, though such models are conventionally called linear in the neural network literature. Vectorizing over the whole dataset, with the design matrix \(\mathbf{X} \in \mathbb{R}^{n \times d}\) collecting one example per row, we can write \(\hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + b\), where the scalar bias is broadcast across all \(n\) predictions.
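The vectorized prediction can be sketched in a few lines. This is an illustrative toy example; the particular values of \(\mathbf{X}\), \(\mathbf{w}\), and \(b\) are assumptions chosen only to show the shapes and the broadcast of the bias.

```python
import torch

# Hypothetical toy data: n = 3 examples, d = 2 features.
X = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])   # design matrix, shape (n, d)
w = torch.tensor([0.5, -1.0])    # weight vector, shape (d,)
b = 2.0                          # scalar bias, broadcast over all rows

y_hat = X @ w + b                # shape (n,): one prediction per example
```

Each entry of `y_hat` is \(\mathbf{w}^\top \mathbf{x}^{(i)} + b\); the matrix product computes all \(n\) weighted sums at once.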
The goal of learning is to find the parameters \(\mathbf{w}\) and \(b\) that minimize the discrepancy between the predicted values \(\hat{y}\) and the true labels \(y\).
To quantify this discrepancy, we define a loss function that provides a non-negative number representing the error.
The squared error loss for a single example \(i\) is defined as: \[ l^{(i)}(\mathbf{w}, b) = \frac{1}{2} \left(\hat{y}^{(i)} - y^{(i)}\right)^2 \] The total loss on the training dataset is the average of the individual losses: \[ L(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^n l^{(i)}(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^n \frac{1}{2} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2 \]
The factor of \(1/2\) is included for mathematical convenience; it cancels out when taking the derivative of the quadratic loss function, leaving a cleaner expression for the gradient. This loss corresponds to the assumption that the data was generated with additive Gaussian noise.
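The averaged squared error above translates directly into code. A minimal sketch, with the prediction and target values chosen arbitrarily for illustration:

```python
import torch

# Hypothetical predictions and targets.
y_hat = torch.tensor([0.5, -0.5, -1.5])
y = torch.tensor([1.0, 0.0, -1.0])

# Per-example squared error with the 1/2 convenience factor.
per_example = 0.5 * (y_hat - y) ** 2

# Total training loss: the average over the n examples.
loss = per_example.mean()
```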
2. Implementation of Linear Regression
While linear regression has an analytic closed-form solution, implementing it using gradient-based optimization allows us to extend the approach to more complex, non-linear neural networks later. We utilize high-level APIs from deep learning frameworks to define the model, loss, and optimizer concisely.
In such an implementation, `nn.LazyLinear(1)` defines a fully connected layer computing \(\mathbf{w}^\top \mathbf{x} + b\) with a single scalar output; the input dimension is inferred lazily from the first batch the layer sees. The `configure_optimizers` method specifies that the model parameters should be updated using minibatch SGD with learning rate `lr`.
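A minimal sketch of this setup in PyTorch follows. The class name `LinearRegression` and the learning rate value are illustrative assumptions; note that a lazy layer must see one batch before its parameters exist, so the optimizer is created after a first forward pass.

```python
import torch
from torch import nn

class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        # Output dimension 1; input dimension inferred from the first batch.
        self.net = nn.LazyLinear(1)

    def forward(self, X):
        return self.net(X)

    def configure_optimizers(self, lr=0.03):
        # Update all parameters (w and b) with plain SGD at learning rate lr.
        return torch.optim.SGD(self.parameters(), lr=lr)

model = LinearRegression()
X = torch.randn(4, 2)     # hypothetical batch: 4 examples, 2 features
y_hat = model(X)          # materializes w (shape (1, 2)) and b
optimizer = model.configure_optimizers()
```

From here, a training loop would compute the squared loss on each minibatch, call `backward()`, and step the optimizer.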
3. Non-linear Curve Fitting
A linear model is not restricted to learning relationships that are strictly linear in the raw input variables. By transforming the input features, we can use the machinery of linear regression to fit non-linear curves, such as polynomials. This is a powerful concept where we perform linear regression in a transformed feature space.
Given a scalar input \(x \in \mathbb{R}\) and target \(y \in \mathbb{R}\), polynomial regression models \(y\) as a \(k\)-th degree polynomial of \(x\): \[ \hat{y} = w_1 x + w_2 x^2 + \dots + w_k x^k + b = \sum_{j=1}^k w_j x^j + b \]
Although \(\hat{y}\) is non-linear with respect to the scalar input \(x\), it is linear with respect to the coefficients \(w_1, \dots, w_k\). We can treat the powers of \(x\) (\(x, x^2, \dots, x^k\)) as distinct features. If we define a feature vector \(\boldsymbol{\phi}(x) = [x, x^2, \dots, x^k]^\top\), the model becomes the standard linear regression form \(\hat{y} = \mathbf{w}^\top \boldsymbol{\phi}(x) + b\). This allows us to fit non-linear curves using the same linear regression algorithms described previously.
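The feature map \(\boldsymbol{\phi}\) can be sketched as a small helper; the function name `poly_features` is a hypothetical choice for illustration.

```python
import torch

def poly_features(x, k):
    """Map a batch of scalars x to [x, x^2, ..., x^k]."""
    # x: shape (n,); result: shape (n, k), one polynomial feature per column.
    powers = torch.arange(1, k + 1)
    return x.unsqueeze(1) ** powers

x = torch.tensor([2.0, 3.0])
phi = poly_features(x, 3)
# phi = [[2., 4., 8.], [3., 9., 27.]]
```

Feeding `phi` into the linear model of Section 1 performs polynomial regression with no change to the training algorithm.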
4. Training Issues: Underfitting and Overfitting
The central problem of machine learning is generalization: learning patterns that apply to unseen data rather than merely memorizing the training data. We evaluate models based on two metrics: training error (error on the data used to learn parameters) and validation error (error on held-out data). The relationship between these errors defines two primary failure modes.
- Underfitting occurs when a model is too simple to capture the underlying pattern of the data.
- Overfitting occurs when a model learns the noise or specific idiosyncrasies of the training data rather than the general distribution.
We will now analyze the symptoms and solutions for each case.
5. Underfitting
Underfitting is the phenomenon where the model fails to reduce the training error to a sufficiently low level.
In this state, both the training error and the validation error are high, and the gap between them is small.
Underfitting typically happens when the model lacks sufficient capacity (complexity) to represent the data. For example, trying to fit a high-degree polynomial curve using a straight line (linear model) will result in underfitting. Detecting underfitting is straightforward: if the model cannot achieve a low error rate even on the training set, it is underfitting.
To fix underfitting, we must increase the model’s complexity. This can be done by:
- Adding more features (e.g., polynomial terms).
- Using a more complex architecture (e.g., deeper neural networks).
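The first remedy can be demonstrated with a least-squares sketch: a straight line cannot drive the training error down on quadratic data, but adding an \(x^2\) feature can. The data-generating function \(y = x^2\) and the grid of points are illustrative assumptions.

```python
import torch

# Noise-free data from y = x^2: any straight line must underfit it.
x = torch.linspace(-1.0, 1.0, 20)
y = x ** 2

# Model 1: features [1, x] -- a straight line.
A1 = torch.stack([torch.ones_like(x), x], dim=1)
sol1 = torch.linalg.lstsq(A1, y.unsqueeze(1)).solution
err1 = ((A1 @ sol1).squeeze() - y).pow(2).mean()

# Model 2: features [1, x, x^2] -- capacity now matches the pattern.
A2 = torch.stack([torch.ones_like(x), x, x ** 2], dim=1)
sol2 = torch.linalg.lstsq(A2, y.unsqueeze(1)).solution
err2 = ((A2 @ sol2).squeeze() - y).pow(2).mean()

# err1 stays large no matter how well we optimize (underfitting);
# err2 is numerically zero once the missing feature is added.
```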
6. Overfitting
Overfitting is the phenomenon where the training error is significantly lower than the validation error.
This indicates that the model has memorized the training set but fails to generalize to new data. The generalization gap \(R - R_{\text{emp}}\), the difference between the true (population) risk \(R\) and the empirical (training) risk \(R_{\text{emp}}\), becomes large.
Overfitting typically happens when the model is too complex relative to the amount of available training data. For example, fitting a high-degree polynomial to a small dataset might result in a curve that passes through every training point perfectly but oscillates wildly in between. We detect overfitting by observing a diverging gap between the training loss (which decreases) and the validation loss (which stagnates or increases).
To avoid or mitigate overfitting, we can:
- Increase the amount of training data.
- Reduce the model complexity (e.g., use fewer parameters).
- Apply regularization techniques such as weight decay.
7. Regularization using Weight Decay
Often, we cannot simply increase the dataset size or manually select the perfect model complexity. Instead, we use regularization to constrain the effective complexity of the model during training. The most common technique is weight decay (also known as \(L_2\) regularization).
Weight decay works by adding a penalty term to the loss function that discourages large weights. The modified objective is: \[ L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2 \] where \(\|\mathbf{w}\|^2 = \sum_{j} w_j^2\) is the squared \(L_2\) norm of the weight vector, and \(\lambda \ge 0\) is a regularization hyperparameter. By convention, the bias \(b\) is not penalized.
The term \(\frac{\lambda}{2} \|\mathbf{w}\|^2\) forces the optimization algorithm to balance two goals: minimizing the prediction error and keeping the weights small. The constant \(\lambda\) controls this trade-off. If \(\lambda = 0\), we have standard linear regression. As \(\lambda\) increases, the weights are constrained more heavily.
Why does this avoid overfitting? A model with large weights can change its output very rapidly for small changes in input, allowing it to fit noise and outliers (high complexity). By forcing weights to be small (decaying them towards zero), we ensure the function remains smoother and simpler, effectively reducing the model’s capacity to memorize noise.
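In PyTorch, the penalty is usually applied through the optimizer's `weight_decay` argument rather than by modifying the loss. A minimal sketch, where the values of `lr` and `wd` (playing the role of \(\lambda\)) are illustrative assumptions:

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
wd = 0.1
optimizer = torch.optim.SGD([
    {"params": [model.weight], "weight_decay": wd},  # penalize the weights
    {"params": [model.bias]},                        # leave the bias unregularized
], lr=0.1)

# Isolate the decay effect: with a zero data-loss gradient, one SGD step
# shrinks each weight by the factor (1 - lr * wd).
with torch.no_grad():
    model.weight.fill_(1.0)
model.weight.grad = torch.zeros_like(model.weight)
optimizer.step()
# model.weight is now 0.99 everywhere: pure decay towards zero.
```

Separating the parameters into two groups reflects the convention that only the weights, not the bias, are decayed.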
8. Summary of Key Concepts
- Linear Model: \(\hat{y} = \mathbf{w}^\top \mathbf{x} + b\).
- Squared Loss: \(\frac{1}{2}(\hat{y} - y)^2\) per example, averaged over the training set.
- Underfitting: High training error, high validation error. Solution: Increase complexity.
- Overfitting: Low training error, high validation error. Solution: Regularization (Weight Decay).
- Weight Decay: Adds \(\frac{\lambda}{2} \|\mathbf{w}\|^2\) to the loss to penalize large weights.
These concepts form the bedrock of statistical learning and are prerequisites for understanding deep neural networks.
Conclusion
In this document, we established the mathematical foundation of linear neural networks, defining the prediction model and the squared loss function. We demonstrated that despite its simplicity, the linear model can be implemented using the same powerful framework components—modules, losses, and optimizers—used in deep learning. We also explored how linear models can address non-linear problems via polynomial features. Finally, we categorized the primary challenges of training: underfitting, which requires more model capacity, and overfitting, which requires regularization techniques like weight decay to constrain model complexity and improve generalization.