Optimizing Training: A Guide to Constructing Parameter Groups in PyTorch
In PyTorch, optimizers handle updates to model parameters during training. When you create an optimizer, you can optionally group the model's parameters into distinct sets called parameter groups. Each parameter group can have its own set of hyperparameters, such as the learning rate, momentum, or weight decay. This allows you to fine-tune the optimization process for different parts of your model.
Creating Parameter Groups
There are two main ways to define parameter groups:

- Using a List of Dictionaries:
  - Construct a list where each element is a dictionary representing a parameter group.
  - Each dictionary needs a `'params'` key holding that group's parameters; the remaining keys specify the hyperparameters you want to adjust for that group (e.g., `lr`, `momentum`, `weight_decay`), and their values define the corresponding settings.
  - Pass this list to the optimizer's constructor.

```python
param_groups = [
    {'params': model.base.parameters(), 'lr': 1e-3},       # Lower LR for base layers
    {'params': model.classifier.parameters(), 'lr': 1e-2}  # Higher LR for classifier
]
optimizer = torch.optim.SGD(param_groups)
```

- Using an Iterator of Tensors or a Single Dictionary (Default): pass `model.parameters()` (or a single dictionary) directly to the optimizer; all parameters are placed in one group that uses the hyperparameters given in the constructor.
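A minimal sketch of the default form (the small `nn.Linear` model here is just a stand-in):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# All parameters end up in a single parameter group that uses
# the hyperparameters passed to the constructor.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
print(len(optimizer.param_groups))  # 1
```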
Benefits of Parameter Groups
- Fine-Tuning Learning Rates: You can apply different learning rates to different parts of your model. For example, you might use a lower learning rate for layers that are already well trained (e.g., a pre-trained backbone) and a higher learning rate for layers you want to adapt more (e.g., newly added layers).
- Adjusting Other Hyperparameters: You can customize other optimization hyperparameters (like momentum or weight decay) for specific parameter groups; see the first sketch after this list.
- Modular Learning Rate Control: Parameter groups simplify adjusting learning rates during training: you can modify a specific group's learning rate without affecting the rest of the model; see the second sketch after this list.
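For the second point, a minimal sketch of a common pattern, assuming you want weight decay on weights but not on biases (the single `nn.Linear` stands in for a real model):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Regularize the weight matrix but leave the bias undecayed
optimizer = torch.optim.SGD([
    {'params': [model.weight], 'weight_decay': 1e-4},
    {'params': [model.bias], 'weight_decay': 0.0}
], lr=1e-2)
```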
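And for the third point, `optimizer.param_groups` is a plain list of dictionaries in construction order, so one group's hyperparameters can be changed on the fly. Continuing from the optimizer above:

```python
# Halve the learning rate of the second group only,
# leaving the first group untouched
optimizer.param_groups[1]['lr'] *= 0.5
print([g['lr'] for g in optimizer.param_groups])  # [0.01, 0.005]
```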
Example: Fine-Tuning a Pre-trained Model
Imagine you're fine-tuning a pre-trained convolutional base (e.g., VGG16) for a new image classification task. You might want to keep the base layers' weights relatively stable while allowing the newly added classifier layer to learn more freely. By using parameter groups, you can set a lower learning rate for the base layers and a higher learning rate for the classifier layer.
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Pre-trained base (e.g., VGG16). A stand-in is used here so the
        # example runs; replace it with your actual pre-trained model.
        self.base = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1000))
        # Add a new classifier layer (assuming the base outputs 1000 features)
        self.classifier = nn.Linear(1000, 10)

    def forward(self, x):
        x = self.base(x)
        x = self.classifier(x)
        return x

# Create the model
model = MyModel()

# Define parameter groups
param_groups = [
    {'params': model.base.parameters(), 'lr': 1e-3},       # Low LR keeps the base stable
    {'params': model.classifier.parameters(), 'lr': 1e-2}  # Higher LR for the new classifier
]

# Create the optimizer with parameter groups
optimizer = torch.optim.SGD(param_groups)
```
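For context, one training step with this optimizer might look like the following sketch (dummy batch, assuming the stand-in base above):

```python
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 3, 224, 224)   # dummy image batch
targets = torch.randint(0, 10, (8,))   # dummy labels

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()  # each parameter group is updated with its own learning rate
```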
Applying Different Learning Rates to Different Layers:
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten input if needed
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# Create the model
model = MyModel()

# Define parameter groups (each layer gets its own group here)
param_groups = [
    {'params': model.fc1.parameters()},              # Uses Adam's default LR
    {'params': model.fc2.parameters(), 'lr': 0.01},  # Different LR for fc2
    {'params': model.fc3.parameters()}               # Uses Adam's default LR
]

# Create the optimizer with parameter groups
optimizer = torch.optim.Adam(param_groups)
```
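Groups that omit a hyperparameter inherit the optimizer's default, so `fc1` and `fc3` train with Adam's default learning rate of 1e-3:

```python
print([g['lr'] for g in optimizer.param_groups])  # [0.001, 0.01, 0.001]
```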
Using nn.ModuleDict:
This approach involves creating an `nn.ModuleDict` to store your model's components, grouped logically. Within each group, a container such as `nn.Sequential` (or `nn.ModuleList`) holds the individual layers. This lets you access the parameters of a specific group easily through the module hierarchy.
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Group the model's components logically in a ModuleDict
        self.params = nn.ModuleDict({
            'base': nn.Sequential(nn.Linear(784, 128), nn.ReLU()),
            'classifier': nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 10))
        })

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.params['base'](x)
        x = self.params['classifier'](x)
        return x

# Create the model
model = MyModel()

# Access parameters for specific groups
base_params = model.params['base'].parameters()
classifier_params = model.params['classifier'].parameters()

# Create the optimizer with separate parameter groups
optimizer = torch.optim.SGD([
    {'params': base_params, 'lr': 1e-3},
    {'params': classifier_params, 'lr': 1e-2}
])
```
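One payoff of this layout, as a sketch: parameter groups can be built programmatically from a name-to-learning-rate mapping (the `lrs` dict here is illustrative):

```python
lrs = {'base': 1e-3, 'classifier': 1e-2}  # hypothetical per-group settings
param_groups = [
    {'params': module.parameters(), 'lr': lrs[name]}
    for name, module in model.params.items()
]
optimizer = torch.optim.SGD(param_groups)
```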
Using Submodules:
You can define different sections of your model as submodules within the main model class. This allows you to access each submodule's parameters and set different hyperparameters for them directly through the optimizer.
```python
import torch
import torch.nn as nn

class BaseModule(nn.Module):
    def __init__(self):
        super(BaseModule, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class ClassifierModule(nn.Module):
    def __init__(self):
        super(ClassifierModule, self).__init__()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        return self.fc3(x)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.base = BaseModule()
        self.classifier = ClassifierModule()

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.base(x)
        x = self.classifier(x)
        return x

# Create the model
model = MyModel()

# Create the optimizer with different hyperparameters per submodule
optimizer = torch.optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-3},
    {'params': model.classifier.parameters(), 'lr': 1e-2}
])
```
Choosing the Right Method:
- The list-of-dictionaries method offers a concise and flexible way to define parameter groups.
- The `nn.ModuleDict` approach provides clear organization within your model structure and easy access to each group's parameters.
- Using submodules is helpful when you have well-defined, reusable components in your model.