BaseModel
- class mmengine.model.BaseModel(data_preprocessor=None, init_cfg=None)
Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model, such as weight initialization, batch-input preprocessing (see BaseDataPreprocessor for more information), loss parsing, and model parameter updating. Subclasses that inherit from BaseModel only need to implement the forward method, which computes the loss and predictions; the model can then be trained by the runner.
Examples
>>> @MODELS.register_module()
>>> class ToyModel(BaseModel):
>>>
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.backbone = nn.Sequential()
>>>         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
>>>         self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
>>>         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
>>>         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
>>>         # flatten conv features before the fully-connected layers
>>>         self.backbone.add_module('flatten', nn.Flatten())
>>>         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
>>>         self.backbone.add_module('fc2', nn.Linear(120, 84))
>>>         self.backbone.add_module('fc3', nn.Linear(84, 10))
>>>
>>>         self.criterion = nn.CrossEntropyLoss()
>>>
>>>     def forward(self, batch_inputs, data_samples=None, mode='tensor'):
>>>         if mode == 'tensor':
>>>             return self.backbone(batch_inputs)
>>>         elif mode == 'predict':
>>>             feats = self.backbone(batch_inputs)
>>>             predictions = torch.argmax(feats, 1)
>>>             return predictions
>>>         elif mode == 'loss':
>>>             feats = self.backbone(batch_inputs)
>>>             # labels arrive as a list of tensors; stack them for the criterion
>>>             loss = self.criterion(feats, torch.stack(data_samples))
>>>             return dict(loss=loss)
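A minimal usage sketch for calling the toy model directly in each mode; the batch size, 32×32 input resolution, and labels are illustrative assumptions, not part of the API:
>>> import torch
>>> model = ToyModel()
>>> batch_inputs = torch.randn(2, 3, 32, 32)  # assumed LeNet-style input
>>> data_samples = [torch.tensor(0), torch.tensor(1)]  # toy integer labels
>>> feats = model(batch_inputs, mode='tensor')               # raw logits
>>> preds = model(batch_inputs, mode='predict')              # class indices
>>> losses = model(batch_inputs, data_samples, mode='loss')  # {'loss': ...}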
- Parameters:
data_preprocessor (dict, optional) – The pre-processing config of BaseDataPreprocessor.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
- data_preprocessor
Used for pre-processing data sampled by the dataloader into the format accepted by forward().
- Type:
BaseDataPreprocessor
- cpu(*args, **kwargs)
Overrides this method to additionally call BaseDataPreprocessor.cpu().
- Returns:
The model itself.
- Return type:
nn.Module
- cuda(device=None)
Overrides this method to additionally call BaseDataPreprocessor.cuda().
- abstract forward(inputs, data_samples=None, mode='tensor')
Returns losses or predictions of the training, validation, testing, and simple inference processes.
forward is an abstract method of BaseModel, and subclasses must implement it. It accepts batch_inputs and data_samples processed by data_preprocessor and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step. During distributed data-parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters:
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (list, optional) – Data samples collated by data_preprocessor.
mode (str) – mode should be one of loss, predict and tensor:
loss: called by train_step; returns a dict of loss tensors used for logging.
predict: called by val_step and test_step; returns a list of results used for computing metrics.
tensor: called by custom code to get Tensor-type results.
- Returns:
If mode == 'loss', return a dict of loss tensors used for backward and logging.
If mode == 'predict', return a list of inference results.
If mode == 'tensor', return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type:
dict, list, or torch.Tensor
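As an illustration of how the modes are exercised, the following sketch mirrors roughly what a validation call does, assuming the default BaseDataPreprocessor and the ToyModel from the Examples above (the dict keys must match the keyword names of your forward implementation; ToyModel uses batch_inputs):
>>> model = ToyModel()
>>> data = dict(batch_inputs=torch.randn(2, 3, 32, 32),
...             data_samples=[torch.tensor(0), torch.tensor(1)])
>>> processed = model.data_preprocessor(data, training=False)
>>> predictions = model(**processed, mode='predict')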
- mlu(device=None)
Overrides this method to additionally call BaseDataPreprocessor.mlu().
- musa(device=None)
Overrides this method to additionally call BaseDataPreprocessor.musa().
- npu(device=None)
Overrides this method to additionally call BaseDataPreprocessor.npu().
Note
This generation of NPU (Ascend 910) does not support the use of multiple cards in a single process, so the device index here needs to be consistent with the default device.
- parse_losses(losses)
Parses the raw outputs (losses) of the network.
- Parameters:
losses (dict) – Raw output of the network, which usually contains losses and other necessary information.
- Returns:
There are two elements. The first is the loss tensor passed to optim_wrapper, which may be a weighted sum of all losses, and the second is log_vars, which will be sent to the logger.
- Return type:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]
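A minimal sketch of the default behavior, using hypothetical loss names: every entry whose key contains 'loss' is summed into the returned loss tensor, and all entries (plus the total) are recorded in log_vars:
>>> raw_losses = dict(loss_cls=torch.tensor(1.0),
...                   loss_bbox=torch.tensor(0.5),
...                   accuracy=torch.tensor(0.9))
>>> parsed_loss, log_vars = model.parse_losses(raw_losses)
>>> parsed_loss     # tensor(1.5000), the sum of loss_cls and loss_bbox
>>> list(log_vars)  # ['loss', 'loss_cls', 'loss_bbox', 'accuracy']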
- to(*args, **kwargs)
Overrides this method to additionally call BaseDataPreprocessor.to().
- Returns:
The model itself.
- Return type:
nn.Module
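Because BaseDataPreprocessor tracks its target device, moving the model also moves the preprocessor, so data cast during train_step lands on the correct device. A minimal sketch, assuming a CUDA device is available:
>>> model = ToyModel().to('cuda:0')   # also calls data_preprocessor.to('cuda:0')
>>> model.data_preprocessor.device    # device(type='cuda', index=0)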
- train_step(data, optim_wrapper)
Implements the default model training process, including data preprocessing, model forward propagation, loss calculation, back-propagation, and parameter updating.
During non-distributed training, if subclasses do not override train_step(), EpochBasedTrainLoop or IterBasedTrainLoop will call this method to update model parameters. The default parameter update process is as follows:
1. Calls self.data_preprocessor(data, training=True) to collect batch_inputs and the corresponding data_samples (labels).
2. Calls self(batch_inputs, data_samples, mode='loss') to get the raw loss dict.
3. Calls self.parse_losses to get the parsed_losses tensor used for backward and the dict of loss tensors used to log messages.
4. Calls optim_wrapper.update_params(loss) to update the model.
- Parameters:
data (dict or tuple or list) – Data sampled from the dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns:
A dict of tensors for logging.
- Return type:
Dict[str, torch.Tensor]
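A minimal sketch of driving train_step by hand, assuming the ToyModel from the Examples above and the default BaseDataPreprocessor; the SGD settings and input shapes are illustrative:
>>> import torch
>>> from mmengine.optim import OptimWrapper
>>> model = ToyModel()
>>> optim_wrapper = OptimWrapper(optimizer=torch.optim.SGD(model.parameters(), lr=0.01))
>>> data = dict(batch_inputs=torch.randn(2, 3, 32, 32),
...             data_samples=[torch.tensor(0), torch.tensor(1)])
>>> log_vars = model.train_step(data, optim_wrapper)  # e.g. {'loss': tensor(...)}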