Beyond One-Hot Encoding: `torch.embedding` and Efficient Text Representation in PyTorch
In PyTorch, nn.Embedding (part of the torch.nn module, and backed by the lower-level torch.embedding function) is a building block used in neural networks for tasks that involve categorical data such as words, characters, or other discrete symbols. It represents each category as a dense vector in a lower-dimensional space, and once trained, those vectors can capture semantic relationships between categories.
How Does torch.embedding Work?
Input and Output:
- Input: torch.embedding takes a tensor of integer indices (usually of torch.long data type), where each element is the index of a category in the vocabulary.
- Output: It returns a tensor with one extra trailing dimension of size embedding_dim, in which each input index is replaced by the embedding vector stored at that row of the weight matrix.
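The input and output shapes described above can be checked with a small sketch. The sizes (6 categories, 4 dimensions) and the batch layout are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so the random weights are reproducible

# Hypothetical sizes: 6 categories, 4-dimensional vectors
embedding = nn.Embedding(num_embeddings=6, embedding_dim=4)

# A batch of 2 sequences with 3 indices each (dtype torch.long)
indices = torch.tensor([[0, 2, 5], [1, 1, 3]], dtype=torch.long)

vectors = embedding(indices)
print(vectors.shape)  # torch.Size([2, 3, 4])

# Each output vector is the matching row of the weight matrix
same_rows = torch.equal(vectors[0, 1], embedding.weight[2])
print(same_rows)  # True
```

The output keeps the shape of the index tensor and appends one dimension of size embedding_dim.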
Benefits of Using torch.embedding:
- Reduces Dimensionality: Embedding vectors are typically lower-dimensional than one-hot encoded representations, improving computational efficiency, especially for large vocabularies.
- Captures Relationships: The learned embedding vectors can encode semantic relationships between categories, such as words with similar meanings having similar embedding vectors.
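To make "similar embedding vectors" concrete, here is a minimal sketch that compares vectors with cosine similarity. The three 2-D vectors and the words attached to them are made-up stand-ins for trained embeddings, not values from any real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hand-written vectors standing in for trained embeddings (hypothetical values)
toy_weights = torch.tensor([
    [1.0, 0.9],   # index 0: "cat"
    [0.9, 1.0],   # index 1: "dog"
    [-1.0, 0.2],  # index 2: "car"
])
embedding = nn.Embedding.from_pretrained(toy_weights)

cat, dog, car = embedding(torch.tensor([0, 1, 2]))

sim_cat_dog = F.cosine_similarity(cat, dog, dim=0).item()
sim_cat_car = F.cosine_similarity(cat, car, dim=0).item()
print(sim_cat_dog > sim_cat_car)  # True
```

With real data, this kind of structure only emerges after the embedding layer has been trained on a task.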
Where's the Definition Located?
The core implementation of torch.embedding is written in C++ for performance reasons. The PyTorch documentation does not reproduce that source code, but it does describe the usage and behavior of the Python-facing wrappers, nn.Embedding and torch.nn.functional.embedding.
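One way to see how the Python-facing pieces relate is to compare the module with the functional form. This is a small sketch with arbitrary sizes; it only shows that nn.Embedding and torch.nn.functional.embedding return the same lookup result for the same weight matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Embedding(5, 10)
indices = torch.tensor([2, 1, 0], dtype=torch.long)

# Module call: the layer owns the weight matrix
via_module = layer(indices)

# Functional call: the weight matrix is passed explicitly
via_functional = F.embedding(indices, layer.weight)

print(torch.equal(via_module, via_functional))  # True
```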
Example Usage:
import torch
import torch.nn as nn
# Example vocabulary of 5 words
vocabulary_size = 5
embedding_dim = 10 # Dimensionality of embedding vectors
# Create an embedding layer
embedding = nn.Embedding(vocabulary_size, embedding_dim)
# Example input: tensor([2, 1, 0]) representing indices of 3 words
input_tensor = torch.tensor([2, 1, 0], dtype=torch.long)
# Get the embedding vectors
output = embedding(input_tensor)
print(output.shape) # Output shape: torch.Size([3, 10])
In this example, the embedding layer maps the indices [2, 1, 0] to three embedding vectors of size 10 each (an output of shape (3, 10)), representing the words at those indices in the vocabulary.
Example 1: Basic Embedding Layer
import torch
import torch.nn as nn
# Vocabulary size (number of unique words)
vocabulary_size = 1000
# Dimensionality of embedding vectors (how many features represent each word)
embedding_dim = 128
# Create an embedding layer
embedding = nn.Embedding(vocabulary_size, embedding_dim)
# Example input: a tensor of word indices (e.g., from a sentence)
word_indices = torch.tensor([23, 789, 12], dtype=torch.long)
# Get the embedding vectors for the given word indices
word_embeddings = embedding(word_indices)
print(word_embeddings.shape) # Output shape: torch.Size([3, 128])
This code defines an embedding layer with a vocabulary size of 1000 and an embedding dimension of 128. It then creates a sample tensor containing word indices (e.g., from a sentence) and uses the embedding layer to retrieve the corresponding embedding vectors.
Example 2: Pre-trained Embedding Usage
import torch
from torch.utils.data import TensorDataset, DataLoader

# Load pre-trained embeddings (replace with your actual loading method)
pre_trained_embeddings = torch.randn(1000, 128)  # Example pre-trained embeddings

# Define a function to map word indices to embedding vectors using pre-trained ones
def get_embedding(word_index):
    if word_index < 1000:  # Check if word index is within pre-trained vocabulary
        return pre_trained_embeddings[word_index]
    else:
        # Handle out-of-vocabulary words (e.g., random initialization)
        return torch.randn(128)  # Placeholder for unknown words

# Sample data (word indices); 1567 is outside the pre-trained vocabulary
word_indices = torch.tensor([23, 789, 1567], dtype=torch.long)

# Create a DataLoader to process the data in batches (optional for larger datasets)
dataset = TensorDataset(word_indices)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Use the get_embedding function to get embeddings for each batch
for batch_indices in data_loader:
    batch_embeddings = []
    # TensorDataset yields a list of tensors; the word indices are its first element
    for word_index in batch_indices[0]:
        batch_embeddings.append(get_embedding(word_index))
    batch_embeddings = torch.stack(batch_embeddings)  # Combine embeddings into a tensor
    # Use batch_embeddings for further processing in your model
    # ... (Your model code here)
This example demonstrates using pre-trained embeddings. It defines a function get_embedding that checks if the word index exists in the pre-trained vocabulary. If it does, it returns the corresponding embedding vector. Otherwise, it handles out-of-vocabulary words (e.g., by initializing a random vector).
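The same idea can also be expressed with nn.Embedding.from_pretrained, which avoids the per-index Python loop. The sketch below is one possible variant, not the method used above: it assumes an extra all-zero row reserved for unknown words (a hypothetical convention) and uses random numbers in place of real pre-trained vectors.

```python
import torch
import torch.nn as nn

vocab_size, dim = 1000, 128
pretrained = torch.randn(vocab_size, dim)  # stand-in for real pre-trained vectors

# Append one extra row reserved for unknown words (hypothetical convention)
unk_index = vocab_size
table = torch.cat([pretrained, torch.zeros(1, dim)], dim=0)

# freeze=True keeps the vectors fixed during training
embedding = nn.Embedding.from_pretrained(table, freeze=True)

word_indices = torch.tensor([23, 789, 1567], dtype=torch.long)

# Redirect any out-of-range index to the unknown-word row
safe_indices = torch.where(
    word_indices < vocab_size,
    word_indices,
    torch.full_like(word_indices, unk_index),
)

vectors = embedding(safe_indices)
print(vectors.shape)  # torch.Size([3, 128])
```

Unlike a fresh random vector per lookup, a fixed unknown-word row gives every out-of-vocabulary word the same representation on every call.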
Custom Lookup Table:
- Create a Python dictionary where keys are category indices and values are embedding vectors.
- During training, you'll manually update these embedding vectors based on your learning rule.
- This approach offers more control over the lookup process but can be memory-intensive for large vocabularies.
One-Hot Encoding:
- Convert categorical data into one-hot encoded vectors, where each element represents category membership (1 for the relevant category, 0 for others).
- While this is simple to implement, it can lead to high dimensionality and inefficiency, especially with large vocabularies.
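The link between the two representations can be sketched directly: multiplying a one-hot matrix by the embedding weights selects the same rows as an index lookup. The sizes below reuse the 1000-word, 128-dimension setup from the earlier example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocabulary_size, embedding_dim = 1000, 128
embedding = nn.Embedding(vocabulary_size, embedding_dim)
word_indices = torch.tensor([23, 789, 12], dtype=torch.long)

# One-hot: each word becomes a 1000-element vector with a single 1
one_hot = F.one_hot(word_indices, num_classes=vocabulary_size).float()
print(one_hot.shape)  # torch.Size([3, 1000])

# Multiplying by the weight matrix picks out one row per word
via_matmul = one_hot @ embedding.weight
via_lookup = embedding(word_indices)
print(torch.allclose(via_matmul, via_lookup))  # True
```

The lookup gives the same result without ever materializing the 1000-wide vectors, which is where the efficiency gain comes from.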
Pre-trained Embeddings (without torch.embedding):
- Load pre-trained embeddings from external sources like Word2Vec or GloVe.
- Access the embedding vectors for your categories using indexing based on word or symbol names.
- This leverages pre-trained knowledge but requires careful handling of out-of-vocabulary words.
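A minimal sketch of name-based access is shown here. The two 3-dimensional vectors are invented for illustration (real ones would be parsed from a Word2Vec or GloVe file), and the all-zero fallback for unknown words is just one possible choice.

```python
import torch

# Hypothetical vectors keyed by word; not taken from any real embedding file
pretrained = {
    "cat": torch.tensor([0.1, 0.3, -0.2]),
    "dog": torch.tensor([0.2, 0.25, -0.1]),
}
unknown = torch.zeros(3)  # fallback for out-of-vocabulary words

sentence = ["cat", "dog", "axolotl"]  # "axolotl" is not in the table
vectors = torch.stack([pretrained.get(word, unknown) for word in sentence])
print(vectors.shape)  # torch.Size([3, 3])
```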
Here's a code example for a custom lookup table:
import torch

# Vocabulary size
vocabulary_size = 1000
# Embedding dimension
embedding_dim = 128

# Create a dictionary as a lookup table
embeddings = {}
for i in range(vocabulary_size):
    embeddings[i] = torch.randn(embedding_dim)  # Initialize random embeddings

# Example input: a tensor of word indices
word_indices = torch.tensor([23, 789, 12], dtype=torch.long)

# Get embedding vectors using the lookup table
word_embeddings = []
for index in word_indices.tolist():  # Convert to plain ints so they match the dict keys
    word_embeddings.append(embeddings[index])
word_embeddings = torch.stack(word_embeddings)  # Combine into a tensor
print(word_embeddings.shape)  # Output shape: torch.Size([3, 128])
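The manual update step mentioned for custom lookup tables could look roughly like the sketch below. To keep it short it stores all vectors in a single tensor rather than a dictionary, and the squared-norm loss and learning rate are hypothetical placeholders for a real training objective.

```python
import torch

torch.manual_seed(0)
vocabulary_size, embedding_dim = 1000, 128

# One tensor holding all vectors; requires_grad lets autograd compute gradients
table = torch.randn(vocabulary_size, embedding_dim, requires_grad=True)
word_indices = torch.tensor([23, 789, 12], dtype=torch.long)
before = table.detach().clone()

# Toy objective (hypothetical): shrink the looked-up vectors toward zero
loss = table[word_indices].pow(2).sum()
loss.backward()

learning_rate = 0.1
with torch.no_grad():
    table -= learning_rate * table.grad  # plain gradient descent step
    table.grad.zero_()                   # reset gradients for the next step

# Only the rows that were looked up have changed
changed = (table.detach() != before).any(dim=1)
print(changed.sum().item())  # 3
```

This is the bookkeeping that nn.Embedding combined with an optimizer normally handles for you.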
Choosing the right approach depends on factors like:
- Dataset size and vocabulary: torch.embedding is generally more efficient for large datasets.
- Control over learning: Custom lookup tables offer more control but require manual updates.
- Pre-trained knowledge utilization: Pre-trained embeddings can be beneficial if appropriate ones exist.