
Commit 457850d

[zero] prevent poor configs from running w. zero-offload (deepspeedai#2971)
1 parent 58a4a4d

File tree

5 files changed (+60 lines, -1 line)
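In short: ZeRO-Offload moves optimizer state (and the optimizer step) to CPU, where a stock torch optimizer such as torch.optim.Adam is typically far slower than DeepSpeed's vectorized DeepSpeedCPUAdam. This commit adds a zero_force_ds_cpu_optimizer flag (default True) that makes engine initialization fail fast on that combination instead of silently running slowly. A minimal sketch of the opt-out path, assuming a working DeepSpeed install and launcher environment; the one-layer model and hyperparameters are placeholders:

    import torch
    import deepspeed

    net = torch.nn.Linear(10, 10)  # placeholder model
    config = {
        "train_batch_size": 4,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 1,
            "offload_optimizer": {"device": "cpu"}
        },
        # Default is True; with the default, initialize() below would raise
        # ZeRORuntimeException because torch.optim.Adam is not DeepSpeedCPUAdam.
        "zero_force_ds_cpu_optimizer": False,
    }
    engine, _, _, _ = deepspeed.initialize(model=net,
                                           optimizer=torch.optim.Adam(net.parameters()),
                                           config=config)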

deepspeed/runtime/config.py

Lines changed: 8 additions & 0 deletions
@@ -502,6 +502,12 @@ def get_zero_allow_untested_optimizer(param_dict):
                             ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)
 
 
+def get_zero_force_ds_cpu_optimizer(param_dict):
+    return get_scalar_param(param_dict,
+                            ZERO_FORCE_DS_CPU_OPTIMIZER,
+                            ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)
+
+
 def get_scheduler_name(param_dict):
     if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
         return param_dict[SCHEDULER][TYPE]

@@ -859,6 +865,8 @@ def _initialize_params(self, param_dict):
         self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
             param_dict)
 
+        self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)
+
         self.scheduler_name = get_scheduler_name(param_dict)
         self.scheduler_params = get_scheduler_params(param_dict)
deepspeed/runtime/constants.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@
 #############################################
 ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
 ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
+ZERO_FORCE_DS_CPU_OPTIMIZER = "zero_force_ds_cpu_optimizer"
+ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT = True
 
 # Steps
 STEPS_PER_PRINT = "steps_per_print"

deepspeed/runtime/engine.py

Lines changed: 10 additions & 0 deletions
@@ -719,6 +719,9 @@ def zero_optimization(self):
     def zero_allow_untested_optimizer(self):
         return self._config.zero_allow_untested_optimizer
 
+    def zero_force_ds_cpu_optimizer(self):
+        return self._config.zero_force_ds_cpu_optimizer
+
     def zero_reduce_scatter(self):
         return self._config.zero_config.reduce_scatter
 
@@ -1265,6 +1268,13 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
             else:
                 basic_optimizer = client_optimizer(model_parameters)
                 log_dist('Using client callable to create basic optimizer', ranks=[0])
+
+            if self.zero_use_cpu_optimizer() and not isinstance(
+                    basic_optimizer,
+                    deepspeed.ops.adam.DeepSpeedCPUAdam):
+                if self.zero_force_ds_cpu_optimizer():
+                    msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
+                    raise ZeRORuntimeException(msg)
         else:
             basic_optimizer = self._configure_basic_optimizer(model_parameters)
             log_dist(
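Note the guard only runs on the client-optimizer branch, and the raised message names two sanctioned exits. The first is to construct the CPU-optimized optimizer yourself, which satisfies the isinstance check; a sketch with illustrative hyperparameters, assuming a DeepSpeed build with the CPUAdam op available:

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    net = torch.nn.Linear(10, 10)  # placeholder model
    # DeepSpeedCPUAdam passes the isinstance() check above, so initialization
    # succeeds even with zero_force_ds_cpu_optimizer left at its default (True).
    optimizer = DeepSpeedCPUAdam(net.parameters(), lr=1e-3)

The second exit is to declare the optimizer in the ds-config instead of passing one to deepspeed.initialize, so the else: branch (_configure_basic_optimizer) chooses an offload-aware implementation itself.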

tests/unit/runtime/half_precision/test_fp16.py

Lines changed: 2 additions & 1 deletion
@@ -466,7 +466,8 @@ def test(self, zero_stage, use_cpu_offload):
                 "stage": zero_stage,
                 "cpu_offload": use_cpu_offload
             },
-            "zero_allow_untested_optimizer": False
+            "zero_allow_untested_optimizer": False,
+            "zero_force_ds_cpu_optimizer": False
         }
         hidden_dim = 10

tests/unit/runtime/zero/test_zero.py

Lines changed: 38 additions & 0 deletions
@@ -18,6 +18,7 @@
 from deepspeed.runtime.engine import DeepSpeedEngine
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+from deepspeed.runtime.zero.utils import ZeRORuntimeException
 from deepspeed.accelerator import get_accelerator
 
 
@@ -1384,3 +1385,40 @@ def forward(self, x, y):
         loss = loss[1]
         model.backward(loss)
         model.step()
+
+
+@pytest.mark.parametrize('force_ds_optim', [True, False])
+class TestZeroOffloadOptim(DistributedTest):
+    world_size = 1
+
+    def test(self, force_ds_optim):
+        config_dict = {
+            "train_batch_size": 4,
+            "gradient_accumulation_steps": 2,
+            "steps_per_print": 1,
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 1,
+                "offload_optimizer": {
+                    "device": "cpu"
+                }
+            },
+            "zero_force_ds_cpu_optimizer": force_ds_optim,
+        }
+        hidden_dim = 10
+
+        model = SimpleModel(hidden_dim)
+
+        optimizer = torch.optim.Adam(model.parameters())
+
+        if force_ds_optim:
+            with pytest.raises(ZeRORuntimeException):
+                model, _, _, _ = deepspeed.initialize(model=model,
+                                                      optimizer=optimizer,
+                                                      config=config_dict)
+        else:
+            model, _, _, _ = deepspeed.initialize(model=model,
+                                                  optimizer=optimizer,
+                                                  config=config_dict)
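The test covers the raising path and the opt-out path; the remaining path from the error message, declaring the optimizer in the config so DeepSpeed selects the CPU-friendly implementation itself, is sketched below, again with a placeholder model and hyperparameters:

    import torch
    import deepspeed

    net = torch.nn.Linear(10, 10)  # placeholder model
    config = {
        "train_batch_size": 4,
        "fp16": {"enabled": True},
        # No client optimizer is passed to initialize(); declaring it here lets
        # _configure_basic_optimizer pick an offload-aware Adam variant.
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {
            "stage": 1,
            "offload_optimizer": {"device": "cpu"}
        },
    }
    engine, optimizer, _, _ = deepspeed.initialize(model=net, config=config)

The new test class can be run in isolation with pytest -k TestZeroOffloadOptim against tests/unit/runtime/zero/test_zero.py.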
