ZeroRedundancyOptimizer¶
- class mmengine.optim.ZeroRedundancyOptimizer(params, optimizer_type, **kwargs)[source]¶
A wrapper class of ZeroRedundancyOptimizer that accepts the local optimizer type as a string.

This class wraps an arbitrary torch.optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size of the parameters and hence only needs to keep 1 / world_size of the optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.8.

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.12 to enable param groups.

- Parameters:
- params (Iterable) – an Iterable of torch.Tensors or dicts giving all parameters, which will be sharded across ranks.
- optimizer_type (str) – the class name of the local optimizer, given as a string.
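
The snippet below is a minimal sketch of constructing the optimizer directly. It assumes the script is launched with a distributed launcher such as torchrun, that GPUs and the NCCL backend are available, and it uses AdamW as the local optimizer purely for illustration; the extra keyword arguments (lr, weight_decay) are forwarded to that local optimizer.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from mmengine.optim import ZeroRedundancyOptimizer

# Assumes launch via torchrun, which sets LOCAL_RANK and the rendezvous
# environment variables needed to initialize the process group.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# A toy model wrapped in DDP; ZeRO shards the optimizer states, DDP keeps
# the model replicas synchronized.
model = DDP(nn.Linear(1024, 1024).cuda(), device_ids=[local_rank])

# The local optimizer is selected by name; remaining kwargs are passed to it.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_type='AdamW', lr=1e-3, weight_decay=0.01)

# One dummy training step: each rank updates only its own parameter shard,
# then the updated parameters are broadcast to the other ranks.
loss = model(torch.randn(8, 1024).cuda()).sum()
loss.backward()
optimizer.step()
```

In an MMEngine config, the same class is typically selected through the optimizer wrapper, e.g. optim_wrapper = dict(optimizer=dict(type='ZeroRedundancyOptimizer', optimizer_type='AdamW', lr=1e-3)).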