Example Codes for Saving Model Weights to MLflow with PyTorch Lightning
- PyTorch: A deep learning framework for building and training neural networks.
- PyTorch Lightning: A framework simplifying PyTorch development by handling boilerplate code like training loops, logging, and callbacks.
- MLflow: An open-source platform for managing the machine learning lifecycle, including experiment tracking, model registry, and deployment.
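Direct Model Saving with mlflow.pytorch.log_model():
- MLflow's autologging for PyTorch Lightning (mlflow.pytorch.autolog()) already logs the trained model by default. To log a model explicitly, use mlflow.pytorch.log_model(), which stores it as an artifact of the active run. mlflow.pytorch.save_model() only writes the model to a local directory.
```python
import mlflow
import mlflow.pytorch

# ... your PyTorch Lightning training code (produces `model`) ...

with mlflow.start_run() as run:
    # Log the trained model as an artifact named "model" in this run
    mlflow.pytorch.log_model(model, "model")
```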
Considerations:
- This approach saves the entire model object (architecture and weights, pickled); it does not include the optimizer state.
- If you only need weights, consider manually saving them using torch.save(model.state_dict(), ...) and logging the file as an artifact.
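To see what "weights only" means in practice, the sketch below (plain PyTorch; the build_model() helper is hypothetical) prints the keys of a state_dict, saves it to a temporary file, and restores it into a freshly built model. The file holds tensors only, so the architecture has to be recreated in code before load_state_dict().

```python
import os
import tempfile

import torch


def build_model():
    # Hypothetical architecture; it lives in code, only tensors go into the file
    return torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1))


model = build_model()
weights = model.state_dict()  # OrderedDict of parameter/buffer tensors
print(list(weights.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(weights, path)

# Restoring requires rebuilding the same architecture first
fresh = build_model()
fresh.load_state_dict(torch.load(path))
x = torch.randn(3, 4)
print(torch.equal(model(x), fresh(x)))  # True
```

The ReLU has no parameters, which is why index 1 is missing from the keys.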
Alternative: Checkpointing with MLflow Autologging (Recommended):
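- MLflow's autologging (mlflow.pytorch.autolog(), called before trainer.fit()) logs metrics, parameters, and the trained model. The checkpointing itself comes from Lightning's ModelCheckpoint callback, which is enabled by default (enable_checkpointing=True in the Trainer; the older checkpoint_callback=True argument was deprecated and has since been removed). Recent MLflow versions can also log these checkpoints through autolog's checkpoint options.
- These checkpoints include weights, optimizer state, and potentially other training information.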
- Access these checkpoints within MLflow for evaluation or deployment.
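```python
import mlflow.pytorch
from pytorch_lightning import Trainer

# ... your PyTorch Lightning model and data setup ...
mlflow.pytorch.autolog()  # enable MLflow autologging before fitting
trainer = Trainer(enable_checkpointing=True)  # the default; replaces checkpoint_callback=True
trainer.fit(model)  # assumes the model defines train_dataloader()
# Metrics, parameters, and the final model are logged to MLflow by autologging
```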
Key Points:
- For weight-only saving, consider manual saving and artifact logging.
- Checkpointing with autologging is generally recommended for comprehensive model state management during training.
- Use MLflow's UI or API to access these models for further analysis or deployment.
Saving the Entire Model with mlflow.pytorch.log_model():
```python
import mlflow.pytorch
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)  # ... your model definition ...

    def forward(self, x):
        return self.layer(x)


model = MyModel()
# ... your PyTorch Lightning training code ...
with mlflow.start_run() as run:
    # Train your model
    # ...
    # Log the whole model (architecture + weights, no optimizer state) as artifact "model"
    mlflow.pytorch.log_model(model, "model")
```
Explanation:
- Import the necessary libraries (mlflow.pytorch and torch).
- Define your PyTorch model class (MyModel).
- Implement your PyTorch Lightning training logic (not shown here).
- Inside an mlflow.start_run() context:
  - Train your model.
  - Use mlflow.pytorch.log_model(model, "model") to log the entire model as an artifact named "model" to the MLflow tracking server. (mlflow.pytorch.save_model() with the same arguments would instead write the model to a local directory named "model" without involving the tracking server.)
Saving Model Weights with Checkpointing and Autologging:
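```python
import mlflow.pytorch
import torch
from pytorch_lightning import LightningModule, Trainer

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # ... your model definition ...
        self.layer = torch.nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # ... your training logic ...
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        return loss

    def configure_optimizers(self):  # required for automatic optimization
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# ... your training data preparation (defines train_dataloaders) ...
model = MyModel()
mlflow.pytorch.autolog()  # enable MLflow autologging before fitting
trainer = Trainer(enable_checkpointing=True)  # the default; replaces checkpoint_callback=True
trainer.fit(model, train_dataloaders)
# Metrics, parameters, and the final model are logged to MLflow by autologging
```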
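- Import Trainer and LightningModule from pytorch_lightning.
- Define your PyTorch Lightning model class (MyModel) with training_step and configure_optimizers (other methods are optional).
- Prepare your training data (train_dataloaders).
- Call mlflow.pytorch.autolog() so that MLflow records the training run.
- Create a Trainer instance; checkpointing is enabled by default (enable_checkpointing=True, which replaces the removed checkpoint_callback=True argument).
- Train your model using trainer.fit(model, train_dataloaders).
- MLflow's autologging logs metrics, parameters, and the trained model as artifacts to the MLflow tracking server. Lightning's ModelCheckpoint writes its checkpoint files (including weights) to local disk; recent MLflow versions can also log those checkpoints through autolog's checkpoint options.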
Remember:
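- Choose the approach that aligns with your needs. To package the whole model so it can be reloaded with mlflow.pytorch.load_model(), use approach 1. For weight-only saving, save model.state_dict() with torch.save() and log the file with mlflow.log_artifact(). For comprehensive model state management during training, approach 2 with autologging is recommended.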
- These examples provide the core structure. Customize the model definition, training logic, and data preparation according to your specific use case.
Custom Callback:
- Create a custom PyTorch Lightning callback that triggers weight saving at specific points during training.
- This approach offers more granular control over when to save weights.
Example:
```python
import mlflow
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class WeightSaveCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        if not trainer.is_global_zero:  # write/log only once in multi-process runs
            return
        # Save weights (state_dict only) at the end of each training epoch
        path = f"weights_epoch_{trainer.current_epoch}.pth"
        torch.save(pl_module.state_dict(), path)
        # Log to the active MLflow run instead of opening a new run per epoch
        mlflow.log_artifact(path)


# ... PyTorch Lightning training code ...
trainer = Trainer(callbacks=[WeightSaveCallback()])
with mlflow.start_run():  # one run for the whole training job
    trainer.fit(model, train_dataloaders)
```
Using torch.save() with Manual Logging:
- Save model weights manually using torch.save(model.state_dict(), ...) at desired points in your training code.
- Log the saved weight file as an MLflow artifact using mlflow.log_artifact().
```python
import torch
import mlflow

# ... PyTorch Lightning training code ...

# Save weights after training
torch.save(model.state_dict(), "saved_weights.pth")

# Log the saved weights as an MLflow artifact
with mlflow.start_run() as run:
    mlflow.log_artifact("saved_weights.pth")
```
Third-Party Libraries:
- Explore experiment-tracking tools such as Sacred or Weights & Biases, which offer model versioning and artifact management, including storing weight files. Weights & Biases also ships a PyTorch Lightning integration (WandbLogger).
Choosing the Right Method:
- Consider the following factors when selecting an alternative method:
  - Granularity: How often do you need to save weights?
  - Control: Do you need fine-grained control over when weights are saved?
  - Integration: How seamlessly does the method integrate with your PyTorch Lightning workflow?