ALM
- class humancompatible.train.dual_optim.ALM(m: int = None, lr: float = 0.01, init_duals: float | Tensor = None, penalty: float = 1.0, *, dual_range: Tuple[float, float] = (-100.0, 100.0), momentum: float = 0.0, dampening: float = 0.0, is_ineq: bool = False, restart: bool = False, ctol: float = 0.0, device=None, process_group: ProcessGroup | None = None)
A dual optimizer that solves the dual maximization problem with the Augmented Lagrangian method. It creates the dual variables and updates them at every step. Reference: https://doi.org/10.48550/arXiv.2504.07607
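\[ \begin{aligned} \pmb{\lambda}_{t+1} &\leftarrow \pmb{\lambda}_t + \gamma\, \mathbf{c}_t(\theta_t) \\ \mathcal{L}_{t+1} &\leftarrow f_t(\theta_t) + \pmb{\lambda}_{t+1}^{T} \mathbf{c}_t(\theta_t) + \frac{\rho}{2} \| \mathbf{c}_t(\theta_t) \|_2^2 \end{aligned} \]
where \(f_t\) is the objective, \(\mathbf{c}_t\) is the vector of constraint values, \(\gamma\) is the dual learning rate (lr) and \(\rho\) is the penalty parameter (penalty).
- Parameters: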
m (int) – Number of constraints (determines the number of dual variables to create)
lr (float) – Dual variable update rate.
init_duals (float | Tensor) – Initial values for the new dual variables. Defaults to 0 for all.
penalty (float) – Augmented Lagrangian penalty parameter. Defaults to `1.0`.
dual_range (Tuple[float, float]) – Safeguarding range for dual variables; they are clamped to this range after each update.
momentum (float) – Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to 0 to disable.
dampening (float) – Dampening for momentum. Equivalent to SGD dampening. Set to 0 to disable.
is_ineq (bool) – Whether to treat the constraints as inequalities (`True`) or equalities (`False`). If `True`, dual variables are decreased when their constraints are strictly satisfied and are lower-bounded by max(dual_range[0], 0).
restart (bool) – Whether to set the dual variables to zero immediately when the corresponding constraints are strictly satisfied. Not recommended for stochastic constraints.
ctol (float) – Constraint tolerance; allows tiny violations of constraints to account for noise.
process_group (dist.ProcessGroup, optional) – Distributed process group for DDP. When set, constraint values are averaged across all workers via dist.all_reduce before each dual update, keeping dual variables consistent across replicas. Defaults to None (no synchronization).
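The update rule above can be illustrated on a toy problem. This is a dependency-free sketch, not the library's implementation: it assumes a scalar parameter x, the objective f(x) = x**2 and a single equality constraint c(x) = x - 1 = 0, uses plain floats instead of Tensors, and approximates the primal minimization of the augmented Lagrangian with hand-written gradient descent. The function `alm_toy` and its arguments `outer_steps`, `inner_steps` and `primal_lr` are hypothetical.

```python
def alm_toy(lr=1.0, penalty=1.0, outer_steps=100, inner_steps=50, primal_lr=0.1):
    """Minimize f(x) = x**2 subject to c(x) = x - 1 = 0 with the ALM rule.

    Hypothetical sketch: floats stand in for Tensors, and the primal step is
    plain gradient descent rather than a real PyTorch optimizer.
    """
    x, lam = 0.0, 0.0
    for _ in range(outer_steps):
        c = x - 1.0                   # constraint value c_t(theta_t)
        lam = lam + lr * c            # dual ascent: lambda_{t+1} = lambda_t + gamma * c_t
        for _ in range(inner_steps):  # approximately minimize L_{t+1} over x
            # d/dx [x**2 + lam * (x - 1) + penalty / 2 * (x - 1)**2]
            grad = 2.0 * x + lam + penalty * (x - 1.0)
            x -= primal_lr * grad
    return x, lam


x, lam = alm_toy()
print(f"x = {x:.4f}, lambda = {lam:.4f}")  # x close to 1, lambda close to -2
```

At the constrained optimum x = 1 the penalty gradient vanishes, so stationarity 2x + lambda = 0 gives lambda = -2, which is the value the dual variable approaches.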
- add_constraint_group(m: int, lr: float = None, momentum: float = None, dampening: float = None, init_duals: Tensor = None, dual_range: tuple[float, float] = None, is_ineq: bool = False, restart: bool = False, device=None) → None
Adds a group of dual variables with its own initial values, learning rate and other hyperparameters.
- Parameters:
m (int) – Size of group (number of dual variables to add)
lr (float) – Dual variable update rate.
momentum (float) – Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to 0 to disable.
dampening (float) – Dampening for momentum. Equivalent to SGD dampening. Set to 0 to disable.
init_duals (Tensor) – Initial values for the new dual variables. Defaults to the value set when creating the optimizer.
dual_range (Tuple[float, float]) – After each dual update, the dual variables will be clamped to this range.
is_ineq (bool) – Whether to treat the constraints as inequalities (`True`) or equalities (`False`). If `True`, dual variables are decreased when their constraints are strictly satisfied and are lower-bounded by max(dual_range[0], 0).
restart (bool) – Whether to set the dual variables to zero immediately when the corresponding constraints are strictly satisfied. Not recommended for stochastic constraints.
Note
Parameters left as None default to the values set when initializing the dual optimizer.
- property duals: Tensor
- Returns:
Dual variables, concatenated into a single tensor.
- Return type:
Tensor
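How per-group hyperparameters fall back to the optimizer-wide values, and how the duals property flattens all groups, can be sketched as follows. Everything here is hypothetical: `DualGroup` and `DualGroupsSketch` are illustrative stand-ins using floats and lists, the sketch's constructor registers no group of its own, and only the None-means-default behaviour and the max(dual_range[0], 0) lower bound for inequality groups are taken from the descriptions above.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass
class DualGroup:
    """One constraint group (hypothetical container, not the library's class)."""
    duals: List[float]
    lr: float
    momentum: float
    dampening: float
    lower: float
    upper: float
    is_ineq: bool
    restart: bool


class DualGroupsSketch:
    """Hypothetical stand-in showing group defaults and the concatenated duals view."""

    def __init__(self, lr: float = 0.01, momentum: float = 0.0, dampening: float = 0.0,
                 init_dual: float = 0.0,
                 dual_range: Tuple[float, float] = (-100.0, 100.0)):
        self.defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                             init_dual=init_dual, dual_range=dual_range)
        self.groups: List[DualGroup] = []

    def add_constraint_group(self, m: int, lr: Optional[float] = None,
                             momentum: Optional[float] = None,
                             dampening: Optional[float] = None,
                             init_duals: Optional[Sequence[float]] = None,
                             dual_range: Optional[Tuple[float, float]] = None,
                             is_ineq: bool = False, restart: bool = False) -> None:
        d = self.defaults
        # None means "fall back to the optimizer-wide value".
        lo, hi = dual_range if dual_range is not None else d["dual_range"]
        if is_ineq:
            lo = max(lo, 0.0)  # inequality duals are lower-bounded by max(dual_range[0], 0)
        if init_duals is None:
            init_duals = [d["init_dual"]] * m
        if len(init_duals) != m:
            raise ValueError("init_duals must contain exactly m values")
        self.groups.append(DualGroup(
            duals=list(init_duals),
            lr=d["lr"] if lr is None else lr,
            momentum=d["momentum"] if momentum is None else momentum,
            dampening=d["dampening"] if dampening is None else dampening,
            lower=lo, upper=hi, is_ineq=is_ineq, restart=restart,
        ))

    @property
    def duals(self) -> List[float]:
        # Flatten all groups, in the order they were added, into one sequence.
        return [v for g in self.groups for v in g.duals]


opt = DualGroupsSketch(lr=0.05)
opt.add_constraint_group(2)
opt.add_constraint_group(1, lr=0.5, init_duals=[3.0], dual_range=(-5.0, 10.0), is_ineq=True)
print(opt.duals)  # [0.0, 0.0, 3.0]
```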
- forward(loss: Tensor, constraints: Tensor) → Tensor
Computes and returns the augmented Lagrangian:
L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2
where loss is the objective value, duals_i and constraints_i are the dual variables and constraint values of group i, penalty is the penalty parameter, and ||constraints|| is the Euclidean norm over all constraint values.
- Parameters:
loss (Tensor) – Loss (objective function) value
constraints (Tensor) – Tensor of constraint values
- Returns:
Lagrangian
- Return type:
Tensor
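The formula above can be written out directly. This sketch is an illustration, not the library's code: it assumes the constraint values have already been split into per-group lists of floats (the documented method takes a single Tensor), and `augmented_lagrangian` is a hypothetical helper.

```python
from typing import Sequence


def augmented_lagrangian(loss: float,
                         duals_per_group: Sequence[Sequence[float]],
                         constraints_per_group: Sequence[Sequence[float]],
                         penalty: float) -> float:
    """loss + sum_i duals_i . constraints_i + 0.5 * penalty * ||constraints||^2."""
    linear = sum(
        sum(l * c for l, c in zip(duals, cons))
        for duals, cons in zip(duals_per_group, constraints_per_group)
    )
    # The squared norm runs over every constraint value, across all groups.
    sq_norm = sum(c * c for cons in constraints_per_group for c in cons)
    return loss + linear + 0.5 * penalty * sq_norm


L = augmented_lagrangian(1.0, [[0.5, -1.0], [2.0]], [[0.2, 0.1], [-0.3]], penalty=2.0)
print(round(L, 6))  # 0.54
```

Because the per-group dot products simply add up, the result is the same as one dot product between the concatenated duals and the concatenated constraint values.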
- forward_update(loss: Tensor, constraints: Tensor) → Tensor
Combines forward and update; slightly faster than calling both separately.
Updates dual variables:
duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound)
Then computes the augmented Lagrangian:
L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2
where buffer_i is the momentum buffer, updated as in update().
- Parameters:
loss (Tensor) – Loss (objective function) value
constraints (Tensor) – Tensor of constraint values
- Returns:
Lagrangian
- Return type:
Tensor
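The order of operations matters: the dual variables are updated first, and the Lagrangian is then evaluated with the updated duals, matching the class-level equation. The following sketch assumes momentum = 0 (so buffer_i equals constraints_i) and a single constraint group of floats; `forward_update_sketch` is a hypothetical helper, not the library's method.

```python
from typing import List, Sequence, Tuple


def forward_update_sketch(loss: float, duals: Sequence[float], constraints: Sequence[float],
                          lr: float, penalty: float,
                          dual_range: Tuple[float, float] = (-100.0, 100.0)
                          ) -> Tuple[List[float], float]:
    """Return (updated duals, augmented Lagrangian); momentum omitted (momentum = 0)."""
    lo, hi = dual_range
    # 1) Dual step: with momentum = 0 the buffer is just the constraint value.
    new_duals = [min(max(l + lr * c, lo), hi) for l, c in zip(duals, constraints)]
    # 2) Lagrangian evaluated with the *updated* duals.
    linear = sum(l * c for l, c in zip(new_duals, constraints))
    sq_norm = sum(c * c for c in constraints)
    return new_duals, loss + linear + 0.5 * penalty * sq_norm


new_duals, L = forward_update_sketch(1.0, [0.0, 1.0], [0.5, -2.0], lr=0.25, penalty=1.0)
print(new_duals, L)  # [0.125, 0.5] 2.1875
```

Evaluating the same Lagrangian with the old duals [0.0, 1.0] would give 1.125 instead, which is why forward_update is not interchangeable with calling forward before update.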
- step(constraints: Tensor) → None
Updates the dual variables by projected gradient ascent with optional momentum.
For each constraint group, first update the momentum buffer:
if momentum > 0:
    buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i
else:
    buffer_i = constraints_i
Then update the dual variables and clamp them:
duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound)
where buffer_i is the momentum buffer, constraints_i are the constraint values, duals_i are the dual variables, and clamp(x, lb, ub) projects x onto the dual range [lb, ub].
- Parameters:
constraints (Tensor) – Tensor of constraint values
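The buffer and clamping rules of step() translate directly into the sketch below. It is illustrative only: `DualAscentSketch` is a hypothetical class working on a single group of floats, and it assumes the momentum buffer starts at zero, which the description does not specify (torch.optim.SGD, by contrast, seeds its buffer with the first gradient).

```python
from typing import List, Sequence, Tuple


class DualAscentSketch:
    """Projected dual ascent with SGD-style momentum (hypothetical, single group).

    Assumption: the momentum buffer is initialized to zeros.
    """

    def __init__(self, duals: Sequence[float], lr: float, momentum: float = 0.0,
                 dampening: float = 0.0,
                 dual_range: Tuple[float, float] = (-100.0, 100.0)):
        self.duals: List[float] = list(duals)
        self.lr, self.momentum, self.dampening = lr, momentum, dampening
        self.lower, self.upper = dual_range
        self.buffer: List[float] = [0.0] * len(self.duals)

    def step(self, constraints: Sequence[float]) -> None:
        for i, c in enumerate(constraints):
            if self.momentum > 0:
                self.buffer[i] = self.momentum * self.buffer[i] + (1 - self.dampening) * c
            else:
                self.buffer[i] = c
            # Gradient ascent on the dual, then projection (clamp) onto the dual range.
            self.duals[i] = min(max(self.duals[i] + self.lr * self.buffer[i], self.lower),
                                self.upper)


opt = DualAscentSketch([0.0], lr=1.0, momentum=0.5)
opt.step([1.0])   # buffer = 1.0, dual = 1.0
opt.step([1.0])   # buffer = 0.5 * 1.0 + 1.0 = 1.5, dual = 2.5
print(opt.duals)  # [2.5]
```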
- update(constraints: Tensor) → None
Updates the dual variables by projected gradient ascent with optional momentum.
For each constraint group, first update the momentum buffer:
if momentum > 0:
    buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i
else:
    buffer_i = constraints_i
Then update the dual variables and clamp them:
duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound)
where buffer_i is the momentum buffer, constraints_i are the constraint values, duals_i are the dual variables, and clamp(x, lb, ub) projects x onto the dual range [lb, ub].
- Parameters:
constraints (Tensor) – Tensor of constraint values
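The update formula does not spell out how is_ineq and restart enter the dual step. The sketch below is a guess under loudly stated assumptions: inequality constraints are written as c(theta) <= 0, "strict satisfaction" means c < 0, restart resets the dual before any ascent step, momentum is omitted and ctol is ignored. Only the lower bound max(dual_range[0], 0) comes from the parameter descriptions; `ineq_dual_update` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def ineq_dual_update(duals: Sequence[float], constraints: Sequence[float], lr: float,
                     dual_range: Tuple[float, float] = (-100.0, 100.0),
                     restart: bool = False) -> List[float]:
    """Hypothetical inequality-group update; constraints assumed to be c(theta) <= 0."""
    lower = max(dual_range[0], 0.0)  # documented lower bound when is_ineq=True
    upper = dual_range[1]
    new_duals = []
    for lam, c in zip(duals, constraints):
        if restart and c < 0:
            # Assumption: "strict satisfaction" means c < 0; the dual is reset to zero.
            new_duals.append(0.0)
        else:
            new_duals.append(min(max(lam + lr * c, lower), upper))
    return new_duals


print(ineq_dual_update([2.0, 0.5], [-0.5, 0.25], lr=1.0))                # [1.5, 0.75]
print(ineq_dual_update([2.0, 0.5], [-0.5, 0.25], lr=1.0, restart=True))  # [0.0, 0.75]
```

Under these assumptions a satisfied constraint drives its dual down toward 0 but never below it, while restart drops it to 0 in one step; with noisy constraint estimates that jump can discard useful dual information, which fits the advice against restart for stochastic constraints.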