Basic Usage
This page provides an overview of using humancompatible-train for constrained deep learning on a simple example.
Idea
The core of the package is formed by Lagrangian-based dual optimizers, which are PyTorch Optimizer-like objects that handle the constrained part of constrained deep learning.
They create, keep track of, and update the dual parameters of the constrained minimization problem, as well as calculate the Lagrangian that is then minimized by a standard PyTorch optimizer in place of a loss.
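As background, here is a generic sketch of the problem these optimizers address. \( f \) is the training loss, \( g_i \) are the inequality constraints and \( h_j \) are the equality constraints. The plain Lagrangian and a projected dual-ascent step with step size \( \eta \) are shown below. This is the textbook formulation, not necessarily the exact update the package implements. Augmented Lagrangian methods such as ALM, for instance, add a penalty term.

\[
\min_{\theta} \; f(\theta) \quad \text{s.t.} \quad g_i(\theta) \leq 0, \quad h_j(\theta) = 0
\]
\[
\mathcal{L}(\theta, \lambda, \mu) = f(\theta) + \sum_i \lambda_i \, g_i(\theta) + \sum_j \mu_j \, h_j(\theta), \qquad \lambda_i \geq 0
\]
\[
\lambda_i \leftarrow \max\bigl(0,\; \lambda_i + \eta \, g_i(\theta)\bigr), \qquad \mu_j \leftarrow \mu_j + \eta \, h_j(\theta)
\]

A standard optimizer then updates \( \theta \) by minimizing \( \mathcal{L} \). In this package, that is the job of the ordinary PyTorch optimizer.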
Simple Example
Let us demonstrate this on a fairness-constrained learning task. We want to learn a classifier that is accurate but also satisfies a demographic parity constraint, i.e. we would like

\[ \left| P(Y = 1 \mid X \in G_1) - P(Y = 1 \mid X \in G_2) \right| \leq \epsilon, \]

where \( Y \) is the prediction given by our model for sample \( X \), \( G_1 \) and \( G_2 \) are the two demographic groups, and \( \epsilon \) is some small threshold.
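The sketch below makes this definition concrete. It estimates each group's positive prediction rate as the fraction of that group's members with a hard prediction \( Y = 1 \), then reports the absolute gap. The tensors are made-up toy values, and the helper name `dp_gap` is hypothetical.

```python
import torch


def dp_gap(preds: torch.Tensor, group: torch.Tensor) -> float:
    """Absolute difference in positive prediction rates between group 0 and group 1.

    preds: hard binary predictions (0/1), shape (N,)
    group: group membership (0/1), shape (N,)
    """
    rate_0 = preds[group == 0].float().mean()  # estimate of P(Y = 1 | group 0)
    rate_1 = preds[group == 1].float().mean()  # estimate of P(Y = 1 | group 1)
    return (rate_0 - rate_1).abs().item()


preds = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0])
group = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
gap = dp_gap(preds, group)  # group 0: 3/4 positive, group 1: 1/4 positive
epsilon = 0.05
print(gap, gap <= epsilon)
```

Hard predictions are not differentiable, so the training code below works with predicted probabilities instead.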
To enforce demographic parity, we will define a constraint function (using the fairret package) that measures the difference in positive prediction rates between two demographic groups.
The dual optimizer will then update the Lagrange multipliers to enforce this constraint during training.
First, let us load and prepare the data. We will use the ACS dataset, containing U.S. Census data, provided by the folktables package. Feel free to skip this section.
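The data-loading code is not reproduced here. If you want to run the remaining snippets without downloading the ACS data, you can use a synthetic stand-in like the one below. It defines the names the later code expects: `features`, `X`, `y`, `groups`, `loader` and `criterion`. Everything in it is an assumption. It uses random features instead of folktables data. It encodes group membership one-hot with shape `(N, 2)`, which fits how `pr_diff` indexes `stats[0]` and `stats[1]`. It uses binary cross-entropy on logits as the loss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
n_samples, n_features = 2000, 10

# Synthetic stand-in for the ACS features (the tutorial itself uses folktables data)
features = torch.randn(n_samples, n_features)

# Hypothetical binary sensitive attribute, one-hot encoded with shape (N, 2)
group_idx = (torch.rand(n_samples) < 0.5).long()
groups = torch.nn.functional.one_hot(group_idx, num_classes=2).float()

# Labels depend on the features and on the group, so the fairness constraint matters
logits_true = features[:, 0] + 0.5 * features[:, 1] + 0.8 * group_idx.float()
y = (logits_true + 0.5 * torch.randn(n_samples) > 0).float().unsqueeze(1)  # shape (N, 1)

X = features
# Batches yield (features, groups, labels), the order the training loops unpack
loader = DataLoader(TensorDataset(X, groups, y), batch_size=128, shuffle=True)

# Binary cross-entropy on raw logits, matching the model's single output unit
criterion = torch.nn.BCEWithLogitsLoss()
```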
Next, we initialize the model and optimizer.

```python
import torch
from torch.nn import Sequential
from torch.optim import AdamW


def setup_model():
    # Small fully connected network producing one logit per sample
    model = Sequential(
        torch.nn.Linear(features.shape[1], 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 1),
    )
    # dummy forward/backward pass to construct torch graph for fair comparison
    model(torch.zeros(features.shape[1])).backward()
    optimizer = AdamW(model.parameters())
    return model, optimizer
```
Next, we define the constraint function for demographic parity, which uses the fairret.statistic.PositiveRate class to evaluate positive rates for both groups.
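As a reminder, we expect our constraints to be of the form \( g(...) \leq 0 \) (inequality) or \( h(...) = 0 \) (equality). Our constraint function returns the gap in positive rates, and we want that gap to be at most \( \epsilon \). So in the training loop we subtract \( \epsilon \) from it, which turns it into a constraint of the form \( g(...) \leq 0 \).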
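```python
import torch
from fairret.statistic import PositiveRate

statistic = PositiveRate()


def pr_diff(logit, groups):
    # Convert logits to probabilities of the positive class
    preds = torch.sigmoid(logit)
    # Positive rate for each group
    stats = statistic(preds, groups)
    # Absolute difference in positive rates between the two groups
    return torch.abs(stats[0] - stats[1])
```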
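Without fairret, the statistic roughly amounts to a soft positive rate per group: the average predicted probability among that group's members. The sketch below computes it, assuming `groups` is a one-hot `(N, k)` tensor. It is our own approximation for illustration, and the function name `soft_positive_rates` is hypothetical. It is not fairret's implementation, which may handle weighting and edge cases differently.

```python
import torch


def soft_positive_rates(logit: torch.Tensor, groups: torch.Tensor) -> torch.Tensor:
    """Mean predicted probability of the positive class in each group.

    logit: model outputs, shape (N, 1)
    groups: one-hot group membership, shape (N, k)
    returns: tensor of shape (k,)
    """
    probs = torch.sigmoid(logit)                 # (N, 1), broadcasts against (N, k)
    weighted = (probs * groups).sum(dim=0)       # sum of probabilities per group
    counts = groups.sum(dim=0).clamp_min(1e-12)  # group sizes, guarded against zero
    return weighted / counts


# Toy values (named toy_* so they do not overwrite the tutorial's data)
toy_logit = torch.tensor([[2.0], [0.0], [-2.0], [0.0]])
toy_groups = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
rates = soft_positive_rates(toy_logit, toy_groups)
toy_gap = torch.abs(rates[0] - rates[1])  # differentiable, so usable as a constraint
```

Because this gap is built from probabilities rather than thresholded predictions, gradients flow through it. That is what allows it to appear in the Lagrangian.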
As a last step, we define our dual optimizer. To set it up, we specify the number of constraints (in our case, 1), so it can create the corresponding dual variables. We also specify the type of constraint (equality or inequality). In a later tutorial, we will see how to create constraint groups with different types and hyperparameters.
```python
from humancompatible.train.dual_optim import ALM

# one constraint (m=1), of inequality type (is_ineq=True)
dual_optimizer = ALM(m=1, lr=0.01, is_ineq=True)
```
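For intuition about how an augmented Lagrangian method handles inequality constraints, one common textbook form (due to Rockafellar) is shown below. Here \( f \) is the loss, \( g_i(\theta) \leq 0 \) are the constraints, \( \rho > 0 \) is a penalty parameter, and \( \lambda_i \geq 0 \) are the multipliers. We have not checked that ALM uses exactly this form, nor how its lr argument relates to \( \rho \). Treat it as background, not as documentation of the package.

\[
\mathcal{L}_\rho(\theta, \lambda) = f(\theta) + \frac{1}{2\rho} \sum_i \Bigl( \max\bigl(0,\; \lambda_i + \rho \, g_i(\theta)\bigr)^2 - \lambda_i^2 \Bigr),
\qquad
\lambda_i \leftarrow \max\bigl(0,\; \lambda_i + \rho \, g_i(\theta)\bigr)
\]

When a constraint is comfortably satisfied and its multiplier is zero, the penalty term vanishes. When it is violated, both the quadratic penalty and the growing multiplier push \( \theta \) back toward the feasible region.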
Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our \( \epsilon \) threshold).
Then, the forward_update step does two things:
1. it updates the dual variables based on the constraint violation;
2. it calculates the Lagrangian based on the loss and the constraint violation.
We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer.
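Both steps of forward_update can be illustrated with a self-contained toy version. The class `ToyALM` and its `forward_update` method are hypothetical stand-ins written for illustration. They follow the textbook inequality-constrained augmented Lagrangian and are not the package's ALM implementation. The toy problem minimizes \( (x - 2)^2 \) subject to \( x - 1 \leq 0 \), whose solution is \( x = 1 \) with multiplier \( \lambda = 2 \).

```python
import torch


class ToyALM:
    """Hypothetical helper for ONE inequality constraint g <= 0 (illustration only)."""

    def __init__(self, rho: float = 1.0):
        self.rho = rho  # penalty parameter, also used as the dual step size
        self.lam = 0.0  # Lagrange multiplier, kept non-negative

    def forward_update(self, loss: torch.Tensor, constraint: torch.Tensor) -> torch.Tensor:
        # Step 1: projected dual ascent on the multiplier using the current violation
        self.lam = max(0.0, self.lam + self.rho * constraint.item())
        # Step 2: augmented Lagrangian with the updated multiplier
        shifted = torch.clamp(self.lam + self.rho * constraint, min=0.0)
        return loss + (shifted**2 - self.lam**2) / (2 * self.rho)


# Toy problem: minimize (x - 2)^2 subject to x - 1 <= 0
x_toy = torch.zeros(1, requires_grad=True)
sgd = torch.optim.SGD([x_toy], lr=0.1)
toy_dual = ToyALM(rho=1.0)

for step in range(2000):
    sgd.zero_grad()
    toy_loss = ((x_toy - 2) ** 2).sum()
    toy_constraint = (x_toy - 1).sum()  # already in g(x) <= 0 form
    toy_lagr = toy_dual.forward_update(toy_loss, toy_constraint)
    toy_lagr.backward()  # gradients flow only into x_toy; the multiplier is a plain float
    sgd.step()

print(x_toy.item(), toy_dual.lam)  # should approach x = 1 and lambda = 2
```

The multiplier is kept as a plain Python float and updated from the detached constraint value, so the primal optimizer never sees a gradient with respect to it.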
```python
model, optimizer = setup_model()
epochs = 10
for epoch in range(epochs):
    # eval: report loss and constraint value on the full training set
    model.eval()
    with torch.no_grad():
        logit = model(X)
        train_loss = criterion(logit, y).item()
        train_fair = pr_diff(logit, groups).item()
    print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}")
    # train
    model.train()
    for batch_feat, batch_groups, batch_label in loader:
        optimizer.zero_grad()
        logit = model(batch_feat)
        loss = criterion(logit, batch_label)
        # constraint in g(...) <= 0 form: positive-rate gap minus the threshold 0.05
        constraint = pr_diff(logit, batch_groups) - 0.05
        # update the dual variables and compute the Lagrangian
        lagr = dual_optimizer.forward_update(loss, constraint)
        lagr.backward()
        optimizer.step()
```
Epoch: 0, loss: 0.7075298428535461, constraint: 0.001788794994354248
Epoch: 1, loss: 0.3985254466533661, constraint: 0.06790915131568909
Epoch: 2, loss: 0.3951018452644348, constraint: 0.05005711317062378
Epoch: 3, loss: 0.38819023966789246, constraint: 0.048632800579071045
Epoch: 4, loss: 0.3837517499923706, constraint: 0.035653531551361084
Epoch: 5, loss: 0.37301546335220337, constraint: 0.046016424894332886
Epoch: 6, loss: 0.37078288197517395, constraint: 0.037894248962402344
Epoch: 7, loss: 0.3614709675312042, constraint: 0.039066046476364136
Epoch: 8, loss: 0.35366758704185486, constraint: 0.03798931837081909
Epoch: 9, loss: 0.34534910321235657, constraint: 0.035340964794158936
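We obtain a respectable loss value while keeping the fairness violation in check. The positive-rate gap briefly exceeds the 0.05 threshold in epochs 1 and 2 (0.068 and 0.050). From then on it stays below the threshold for the rest of training, ending at about 0.035!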
For comparison, let’s see what happens if we train the same model without the constraint:
```python
model, optimizer = setup_model()
for epoch in range(epochs):
    # eval: report loss and constraint value on the full training set
    model.eval()
    with torch.no_grad():
        logit = model(X)
        train_loss = criterion(logit, y).item()
        train_fair = pr_diff(logit, groups).item()
    print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}")
    # train: plain loss minimization, no dual optimizer
    model.train()
    for batch_feat, batch_groups, batch_label in loader:
        optimizer.zero_grad()
        logit = model(batch_feat)
        loss = criterion(logit, batch_label)
        loss.backward()
        optimizer.step()
```
Epoch: 0, loss: 0.6862483620643616, constraint: 0.0014756619930267334
Epoch: 1, loss: 0.3999955654144287, constraint: 0.0830242931842804
Epoch: 2, loss: 0.3909071087837219, constraint: 0.0868237316608429
Epoch: 3, loss: 0.3815554976463318, constraint: 0.09362807869911194
Epoch: 4, loss: 0.3745419383049011, constraint: 0.09206011891365051
Epoch: 5, loss: 0.3680148720741272, constraint: 0.09046411514282227
Epoch: 6, loss: 0.3566308915615082, constraint: 0.09541821479797363
Epoch: 7, loss: 0.3507809340953827, constraint: 0.09349411725997925
Epoch: 8, loss: 0.3389263451099396, constraint: 0.10273250937461853
Epoch: 9, loss: 0.33016437292099, constraint: 0.10011604428291321
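Without the constraint, the absolute difference in positive rates climbs to about 0.10, twice the 0.05 threshold we wanted. Meanwhile the final loss is only slightly lower (0.330 versus 0.345 with the constraint)!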
Further reading: