Distributed Checkpoint - torch.distributed.checkpoint

Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel. It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP differs from torch.save and torch.load in a few significant ways:

  • It produces multiple files per checkpoint, with at least one per rank.

  • It operates in place: the model should allocate its data first, and DCP then loads into that storage instead of allocating its own.

The entrypoints to load and save a checkpoint are the following:

torch.distributed.checkpoint.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

Load a distributed state_dict in SPMD style.

Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor instances, each rank only reads data for its local shards.
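
To make the idea of load-time resharding concrete, the following is a minimal pure-Python sketch. It is not DCP's planning code: it assumes a 1-D tensor saved as contiguous shards, and the names Shard and reads_for are hypothetical. It computes which slices of the saved shards a rank would need to read to fill its new shard.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Shard:
    """Hypothetical contiguous 1-D shard covering [offset, offset + length)."""
    offset: int
    length: int


def reads_for(wanted: Shard, saved: List[Shard]) -> List[Tuple[int, int, int]]:
    """Return (saved_shard_index, offset_inside_saved_shard, length) per overlap."""
    reads = []
    for i, s in enumerate(saved):
        start = max(wanted.offset, s.offset)
        end = min(wanted.offset + wanted.length, s.offset + s.length)
        if start < end:  # the two ranges overlap
            reads.append((i, start - s.offset, end - start))
    return reads


# A 12-element tensor saved from 4 ranks (3 elements each) ...
saved_shards = [Shard(offset=3 * r, length=3) for r in range(4)]
# ... and loaded on 3 ranks (4 elements each).
new_shards = [Shard(offset=4 * r, length=4) for r in range(3)]

plan = {rank: reads_for(shard, saved_shards) for rank, shard in enumerate(new_shards)}
```

In this sketch each loading rank touches only the two saved shards that overlap its own range, which is the "least amount of data" idea in miniature.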

Warning

All tensors in state_dict must be allocated on their destination device prior to calling this function.

All non-tensor data is loaded using torch.load() and modified in place on state_dict.

Warning

Users must call load_state_dict on the root module to ensure load post-processing and non-tensor data propagate properly.

Parameters
  • state_dict (Dict[str, Any]) – The state_dict to load. Note that this state dict will be updated in place.

  • storage_reader (StorageReader) – StorageReader used to load data from.

  • process_group (ProcessGroup) – ProcessGroup to be used for cross-rank synchronization.

  • coordinator_rank (int) – Rank to use to coordinate the checkpoint. rank0 is used by default.

  • no_dist (bool) – If True, distributed checkpoint will not load in SPMD style. (Default: False)

Returns

None.

Return type

None

Examples
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() may have customized steps to flush the
>>> # state_dict, so it must be called to ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

Note

load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

torch.distributed.checkpoint.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

Save a distributed model in SPMD style.

This function differs from torch.save() in that it handles ShardedTensor by having each rank save only its local shards.

Warning

There are no guarantees of backwards compatibility across PyTorch versions for saved state_dicts.

Warning

If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belongs to it.

Note

When saving a checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one shard_group should call save_state_dict, and the corresponding process group needs to be passed in.

Note

This function can be used to save a state_dict without having a process group initialized by passing no_dist=True.
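
As a hedged illustration of the no_dist=True path, the sketch below saves and reloads a small module in a single process with no process group. The torch.nn.Linear model and the temporary directory are assumptions made for the example; newer PyTorch releases may emit a deprecation warning for these entry points.

```python
import os
import tempfile

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(4, 2)
ckpt_dir = tempfile.mkdtemp()  # FileSystemWriter expects an empty or non-existing directory

# Save without any process group being initialized.
dcp.save_state_dict(
    state_dict=model.state_dict(),
    storage_writer=dcp.FileSystemWriter(ckpt_dir),
    no_dist=True,
)

# Loading is in place: allocate the destination tensors first.
restored = torch.nn.Linear(4, 2)
state = restored.state_dict()
dcp.load_state_dict(
    state_dict=state,
    storage_reader=dcp.FileSystemReader(ckpt_dir),
    no_dist=True,
)
# Let the module run its own load post-processing.
restored.load_state_dict(state)

saved_files = sorted(os.listdir(ckpt_dir))
```

After the round trip, restored holds the same weights as model, and saved_files contains the .metadata file alongside the data file(s).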

Parameters
  • state_dict (Dict[str, Any]) – The state_dict to save.

  • storage_writer (StorageWriter) – Instance of StorageWriter used to perform writes.

  • process_group (ProcessGroup) – ProcessGroup to be used for cross-rank synchronization.

  • coordinator_rank (int) – Rank to use to coordinate the checkpoint. rank0 is used by default.

  • no_dist (bool) – If True, distributed checkpoint will not save in SPMD style. (Default: False)

Returns

Metadata object for the saved checkpoint.

Return type

Metadata

Example

>>> my_model = MyModule()
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

Note

save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

This example shows how to use PyTorch Distributed Checkpoint to save an FSDP model.

The following types define the IO interface used during checkpoint:

class torch.distributed.checkpoint.StorageReader[source]

Interface used by load_state_dict to read from storage.

One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls by load_state_dict:

  1. (all ranks) read_metadata()

  2. (all ranks) set_up_storage_reader()

  3. (all ranks) prepare_local_plan()

  4. (coordinator) prepare_global_plan()

  5. (all ranks) read_data()

abstract prepare_global_plan(plans)[source]

Perform centralized planning of storage loading.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage specific data in LoadPlan::storage_data.

Parameters

plans (List[LoadPlan]) – A list of LoadPlan instances, one for each rank.

Returns

A list of transformed LoadPlan after storage global planning

Return type

List[LoadPlan]

abstract prepare_local_plan(plan)[source]

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage specific data in LoadPlan::storage_data.

Parameters

plan (LoadPlan) – The local plan from the LoadPlanner in use.

Returns

A transformed LoadPlan after storage local planning

Return type

LoadPlan

abstract read_data(plan, planner)[source]

Read all items from plan using planner to resolve the data.

A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place.

A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that it should load data into.

It is the storage layer's responsibility to properly schedule any cross-device copies required.

Parameters
  • plan (LoadPlan) – The local plan to execute on

  • planner (LoadPlanner) – The planner object to use to resolve items.

Returns

A future that completes once all reads are finished.

Return type

Future[None]

abstract read_metadata()[source]

Read the checkpoint metadata.

Returns

The metadata object associated with the checkpoint being loaded.

Return type

Metadata

abstract set_up_storage_reader(metadata, is_coordinator)[source]

Initialize this instance.

Parameters
  • metadata (Metadata) – The metadata schema to use.

  • is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

class torch.distributed.checkpoint.StorageWriter[source]

Interface used by save_state_dict to write to storage.

One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls.

  1. (all ranks) set_up_storage_writer()

  2. (all ranks) prepare_local_plan()

  3. (coordinator) prepare_global_plan()

  4. (all ranks) write_data()

  5. (coordinator) finish()
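
The call order above can be sketched with a toy driver. This is only a single-process simulation under stated assumptions: RecordingWriter and simulate_save are hypothetical names, the class does not subclass the real StorageWriter, and the method signatures are simplified (the real write_data also takes a planner, and finish also takes metadata).

```python
from typing import List


class RecordingWriter:
    """Toy stand-in that only records which method was called on which rank."""

    def __init__(self, rank: int, log: List[str]):
        self.rank = rank
        self.log = log

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self.is_coordinator = is_coordinator
        self.log.append(f"rank{self.rank}: set_up_storage_writer")

    def prepare_local_plan(self, plan):
        self.log.append(f"rank{self.rank}: prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans):
        self.log.append(f"rank{self.rank}: prepare_global_plan")
        return plans

    def write_data(self, plan):
        self.log.append(f"rank{self.rank}: write_data")
        return [f"result-from-rank{self.rank}"]

    def finish(self, results) -> None:
        self.log.append(f"rank{self.rank}: finish")


def simulate_save(world_size: int, coordinator_rank: int = 0) -> List[str]:
    log: List[str] = []
    writers = [RecordingWriter(r, log) for r in range(world_size)]
    for w in writers:  # 1. all ranks
        w.set_up_storage_writer(w.rank == coordinator_rank)
    local = [w.prepare_local_plan(f"plan{w.rank}") for w in writers]  # 2. all ranks
    final = writers[coordinator_rank].prepare_global_plan(local)  # 3. coordinator
    results = [w.write_data(p) for w, p in zip(writers, final)]  # 4. all ranks
    writers[coordinator_rank].finish(results)  # 5. coordinator
    return log


call_log = simulate_save(world_size=2)
```

Printing call_log shows that only rank 0 sees prepare_global_plan and finish, while every rank sees the other three calls.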

abstract finish(metadata, results)[source]

Write the metadata and mark the current checkpoint as successful.

The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable into the same object graph.

Parameters
  • metadata (Metadata) – metadata for the new checkpoint

  • results (List[List[WriteResult]]) – A list of WriteResults from all ranks.

Returns

None

Return type

None

abstract prepare_global_plan(plans)[source]

Perform centralized planning of storage.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage specific data in SavePlan::storage_data.

Parameters

plans (List[SavePlan]) – A list of SavePlan instances, one for each rank.

Returns

A list of transformed SavePlan after storage global planning

Return type

List[SavePlan]

abstract prepare_local_plan(plan)[source]

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage specific data in SavePlan::storage_data.

Parameters

plan (SavePlan) – The local plan from the SavePlanner in use.

Returns

A transformed SavePlan after storage local planning

Return type

SavePlan

abstract set_up_storage_writer(is_coordinator)[source]

Initialize this instance.

Parameters

is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

abstract write_data(plan, planner)[source]

Write all items from plan using planner to resolve the data.

A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write.

Subclasses should lazily call resolve_data as it can allocate memory. In the case of tensors, make the following assumptions:

  • They might be on any device, including not matching the one on WriteItem::tensor_data

  • They might be views or not contiguous. Only the projection needs to be saved.

Parameters
  • plan (SavePlan) – The save plan to execute.

  • planner (SavePlanner) – Planner object to be used to resolve items to data.

Returns

A future that completes to a list of WriteResult

Return type

Future[List[WriteResult]]

The following types define the planner interface used during checkpoint:

class torch.distributed.checkpoint.LoadPlanner[source]

Abstract class defining the protocol used by load_state_dict to plan the load process.

LoadPlanners are stateful objects that can be used to customize the whole load process.

LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during load_state_dict:

  1. set_up_planner - called on all ranks.

    Signals the start of loading a checkpoint.

  2. create_local_plan - called on all ranks.

    Processes the state_dict and produces a LoadPlan that will be sent for global planning.

  3. create_global_plan - called on the coordinator rank only.

    Takes the LoadPlan from all ranks and makes any global decisions.

  4. load_bytes - called multiple times on each rank

    This is called once per non-tensor value in state_dict.

  5. resolve_tensor and commit_tensor - called multiple times on each rank

    They are called in pairs for each Tensor value in state_dict.

Users are recommended to extend DefaultLoadPlanner instead of this interface directly as most changes can be expressed by changes in a single method.

There are two usual patterns of extension:

Rewriting state_dict. This is the simplest way to extend the load process as it doesn’t require understanding the intricacies of how LoadPlan works. Because loading happens in place, we need to keep a reference to the original state_dict so that loaded values can be written back into it.

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(self, state_dict, metadata, is_coordinator):
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
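
The key bookkeeping in the example above can be checked in isolation. The following pure-Python sketch uses plain dictionaries and placeholder strings instead of real tensors or a real planner; strip_prefix is a hypothetical helper that has the same effect as the fqn[4:] slice for the "foo_" prefix.

```python
PREFIX = "foo_"

# Stand-in for the state_dict handed to set_up_planner (values are placeholders).
original_state_dict = {"weight": None, "bias": None}

# What set_up_planner does in the example: items are looked up under prefixed names.
renamed = {PREFIX + k: v for k, v in original_state_dict.items()}


def strip_prefix(fqn: str) -> str:
    # Same effect as fqn[4:] in the example, but derived from the prefix length.
    assert fqn.startswith(PREFIX)
    return fqn[len(PREFIX):]


# What load_bytes does: write each loaded value back under the original key.
for fqn in renamed:
    original_state_dict[strip_prefix(fqn)] = f"loaded:{fqn}"
```

Deriving the slice from len(PREFIX) avoids silently corrupting keys if the prefix is later changed.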

Modifying resolve_tensor and commit_tensor to handle load time transformation.

>>> class MetaModelMaterialize(DefaultLoadPlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor

abstract commit_tensor(read_item, tensor)[source]

Called once the StorageReader has finished loading data into tensor.

The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post process tensor prior to copying it back to the one in the state_dict.

The contents of tensor will follow its device synchronization model.

abstract create_global_plan(global_plan)[source]

Compute the global load plan and return plans for each rank.

N.B. This is called on the coordinator rank only.

Return type

List[LoadPlan]

abstract create_local_plan()[source]

Create a LoadPlan based on state_dict and metadata provided by set_up_planner.

N.B. This is called on every rank.

Return type

LoadPlan

abstract finish_plan(central_plan)[source]

Accept the plan from coordinator and return final LoadPlan.

Return type

LoadPlan

abstract load_bytes(read_item, value)[source]

Load the item described by read_item and value.

This method is expected to modify in-place the underlying state_dict.

The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded.

abstract resolve_tensor(read_item)[source]

Return the tensor described by read_item to be used by the StorageReader to load read_item.

The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict.

Return type

Tensor

abstract set_up_planner(state_dict, metadata, is_coordinator)[source]

Initialize this instance to load data into state_dict.

N.B. This is called on every rank.

class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]

class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]

class torch.distributed.checkpoint.SavePlanner[source]

Abstract class defining the protocol used by save_state_dict to plan the save process.

SavePlanners are stateful objects that can be used to customize the whole save process.

SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during save_state_dict:

  1. set_up_planner - called on all ranks.

    Signals the start of a checkpoint save.

  2. create_local_plan - called on all ranks.

    Processes the state_dict and produces a SavePlan that will be sent for global planning.

  3. create_global_plan - called on the coordinator rank only.

    Takes the SavePlan from all ranks and makes any global decisions.

  4. finish_plan - called on all ranks.

    This gives each rank a chance to adjust to global planning decisions.

  5. resolve_data - called multiple times on each rank

    Looks up a value in the state_dict for the storage layer to write.

Users are recommended to extend DefaultSavePlanner instead of this interface directly as most changes can be expressed by changes in a single method.

There are 3 usual patterns of extension:

Rewriting state_dict. This is the simplest way to extend the save process as it doesn’t require understanding the intricacies of how SavePlan works:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(self, state_dict, is_coordinator):
>>>         # prefix all keys with `foo_`
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)

Modifying local plan and lookup in tandem. This is useful when fine control over how data is persisted is needed:

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan.items:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

Using the global planning step to make central decisions that can’t be made individually by each rank:

>>> from itertools import islice
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         def chunk(it, size):
>>>             it = iter(it)
>>>             return list(iter(lambda: tuple(islice(it, size)), ()))
>>>         all_plans = [
>>>             replace(plan, items=items) for plan, items in
>>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
>>>         ]
>>>         return super().create_global_plan(all_plans)
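
The chunking idiom used for load balancing can be exercised on its own. The sketch below is pure Python and makes two assumptions of its own: ToyPlan is a hypothetical stand-in for SavePlan that models only the items field, and the chunk size is chosen by ceiling division so that the items spread across all ranks.

```python
from dataclasses import dataclass, replace
from itertools import islice
from typing import List, Tuple


def chunk(it, size):
    """Split an iterable into tuples of at most `size` items."""
    it = iter(it)
    # iter(callable, sentinel) keeps calling until the empty tuple is returned.
    return list(iter(lambda: tuple(islice(it, size)), ()))


@dataclass(frozen=True)
class ToyPlan:
    """Hypothetical stand-in for SavePlan; only `items` is modeled."""
    items: Tuple[str, ...]


# Rank 0 initially owns every replicated item; the other ranks own none.
all_plans: List[ToyPlan] = [
    ToyPlan(items=("a", "b", "c", "d", "e")),
    ToyPlan(items=()),
    ToyPlan(items=()),
]

# Assumption of this sketch: ceiling division so the items cover all ranks.
per_rank = -(-len(all_plans[0].items) // len(all_plans))
balanced = [
    replace(plan, items=items)
    for plan, items in zip(all_plans, chunk(all_plans[0].items, per_rank))
]
```

With five items and three ranks, the ranks end up with two, two, and one item respectively. dataclasses.replace returns new plan objects, so the incoming plans are left untouched.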

Finally, some planners need to save additional metadata in the checkpoint. This is accomplished by having each rank contribute its data items in the local plan and having the global planner aggregate them:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
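
The aggregation pattern in this example can be tried without a process group. The following is a pure-Python sketch in which ToyPlan and ToyMetadata are hypothetical stand-ins that model only a planner_data field, and the ranks are simulated with a loop.

```python
from dataclasses import dataclass, replace
from typing import Any, List, Tuple


@dataclass(frozen=True)
class ToyPlan:
    """Hypothetical stand-in for SavePlan."""
    planner_data: Any = None


@dataclass(frozen=True)
class ToyMetadata:
    """Hypothetical stand-in for Metadata."""
    planner_data: Any = None


def create_local_plan(rank: int) -> ToyPlan:
    # Each rank attaches its own contribution to the local plan.
    return replace(ToyPlan(), planner_data=f"data-from-rank{rank}")


def create_global_plan(all_plans: List[ToyPlan]) -> Tuple[List[ToyPlan], ToyMetadata]:
    # The coordinator gathers every rank's contribution into the metadata.
    merged = [p.planner_data for p in all_plans]
    return all_plans, replace(ToyMetadata(), planner_data=merged)


local_plans = [create_local_plan(r) for r in range(3)]
_, metadata = create_global_plan(local_plans)
```

The merged list is ordered by rank because the plans are passed to the coordinator in rank order in this simulation.
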
abstract create_global_plan(all_plans)[source]

Compute the global checkpoint plan and return the local plan of each rank.

This is called on the coordinator rank only.

Return type

Tuple[List[SavePlan], Metadata]

abstract create_local_plan()[source]

Compute the save plan for the current rank.

This will be aggregated and passed to create_global_plan. Planner specific data can be passed through SavePlan::planner_data.

This is called on all ranks.

Return type

SavePlan

abstract finish_plan(new_plan)[source]

Merge the plan created by create_local_plan and the result of create_global_plan.

This is called on all ranks.

Return type

SavePlan

abstract resolve_data(write_item)[source]

Transform and prepare write_item from state_dict for storage, ensuring idempotency and thread-safety.

Lookup the object associated with write_item in state_dict and apply any transformation (such as serialization) prior to the storage layer consuming it.

Called on each rank multiple times, at least once per WriteItem in the final SavePlan.

This method should be idempotent and thread-safe. StorageWriter implementations are free to call it as frequently as they need.

Any transformation that allocates memory should be done lazily when this method is called in order to reduce the peak memory required by checkpointing.

When returning tensors, they can be on any device or format, and they can be views too. It’s the storage layer’s responsibility to figure out how to save them.
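
A minimal sketch of the lazy, idempotent behaviour described above, under stated assumptions: LazyBytesPlanner is a hypothetical toy that does not subclass SavePlanner, and its resolve_data takes a plain key rather than a WriteItem.

```python
import io
import pickle
from typing import Any, Dict


class LazyBytesPlanner:
    """Toy planner that serializes a value only when the storage layer asks."""

    def __init__(self, state_dict: Dict[str, Any]):
        self.state_dict = state_dict  # kept as-is; nothing is serialized up front

    def resolve_data(self, key: str) -> io.BytesIO:
        # Allocate the serialized buffer only at call time.
        # No planner state is mutated, so repeated calls return equal bytes.
        return io.BytesIO(pickle.dumps(self.state_dict[key]))


planner = LazyBytesPlanner({"epoch": 7, "tags": ["a", "b"]})
first = planner.resolve_data("tags").getvalue()
second = planner.resolve_data("tags").getvalue()
```

Because each call builds a fresh buffer from unchanged state, a storage layer could call resolve_data more than once for the same item and get the same bytes each time.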

Return type

Union[Tensor, BytesIO]

abstract set_up_planner(state_dict, is_coordinator)[source]

Initialize this planner to save state_dict.

Implementations should save those values as they won’t be provided later in the save process.

This is called on all ranks.

class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source]

class torch.distributed.checkpoint.WriteItem(index: torch.distributed.checkpoint.metadata.MetadataIndex, type: torch.distributed.checkpoint.planner.WriteItemType, tensor_data: Union[torch.distributed.checkpoint.planner.TensorWriteData, NoneType] = None)[source]

We provide a filesystem based storage layer:

class torch.distributed.checkpoint.FileSystemReader(path)[source]

class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[source]

Basic implementation of StorageWriter using file IO.

This implementation makes the following assumptions and simplifications:

  • The checkpoint path is an empty or non-existing directory.

  • File creation is atomic.

The checkpoint consists of one file per write request plus a .metadata file with the serialized metadata.

We provide default implementations of LoadPlanner and SavePlanner that can handle all torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=True)[source]

lookup_object(index)[source]

Extension from the planner interface to make it easy to extend the default planner.

Return type

Any

transform_object(write_item, object)[source]

Extension from the planner interface to make it easy to extend the default planner.

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[source]

DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

In particular it adds the following:

  • flatten_state_dict: Handle state_dict with nested dicts

  • flatten_sharded_tensors: For FSDP in 2D parallel mode

lookup_tensor(index)[source]

Extension from the planner interface to make it easy to extend the default planner.

Return type

Tensor

transform_tensor(read_item, tensor)[source]

Extension from the planner interface to make it easy to extend the default planner.

We provide a set of APIs to help users get and set state_dict easily. This is an experimental feature and is subject to change.

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]

Return the model state_dict and optimizers state_dict.

get_state_dict can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any combination of these parallelisms. The main functions of get_state_dict are: 1.) returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms. 2.) hiding the parallelism-specific state_dict APIs. Users don’t have to call these APIs. 3.) sanity checking the result state_dict.

The keys of the result state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter’s position in an nn.Module hierarchy. More specifically, a canonical FQN to a parameter is the FQN returned by module.named_parameters() or module.named_buffers() when the module is not distributed by any parallelisms. Since the optimizer internally uses parameter IDs to represent a parameter, there will be a conversion from the parameter IDs to the canonical FQNs when calling this API.

get_state_dict can also process a module that is not parallelized. In such a case, get_state_dict only performs one function – converting the optimizer parameter IDs to the canonical FQNs.
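
The conversion from parameter IDs to names can be pictured with plain dictionaries. This is only an illustrative sketch, not how get_state_dict is implemented: ids_to_fqns is a hypothetical helper, the parameter names are made up, and the exact layout that get_state_dict returns is not assumed here.

```python
from typing import Any, Dict, List


def ids_to_fqns(optim_state: Dict[str, Any], fqns: List[str]) -> Dict[str, Any]:
    """Re-key an ID-based optimizer state_dict by parameter name.

    Assumes fqns[i] is the name of the parameter with ID i.
    """
    return {
        "state": {fqns[pid]: s for pid, s in optim_state["state"].items()},
        "param_groups": [
            {**group, "params": [fqns[pid] for pid in group["params"]]}
            for group in optim_state["param_groups"]
        ],
    }


# Shape of an optimizer state_dict: parameters are referred to by integer IDs.
id_based = {
    "state": {0: {"step": 3}, 1: {"step": 3}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
# Names in the order a small hypothetical model would yield its parameters.
names = ["layer1.weight", "layer1.bias"]

fqn_based = ids_to_fqns(id_based, names)
```

Keying by name rather than by ID is what lets a state_dict be matched against a model whose parameters are wrapped or ordered differently.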

Example

>>> import copy
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>>
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
>>>
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>>
>>> # If we simply called ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts would fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state_dict == fsdp_optim_state_dict

Parameters
  • model (nn.Module) – the nn.Module of the model.

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • submodules (Optional[Set[Module]]) – only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns

A tuple that contains the model state_dict and the optimizer state_dict.

Return type

Tuple[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]], Dict[str, Union[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]], List[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]]]]]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]

Return the model state_dict of model.

See get_state_dict for detailed usage.

Parameters
  • model (nn.Module) – the nn.Module of the model.

  • submodules (Optional[Set[Module]]) – only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns

The state_dict for model.

Return type

Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]

Return the combined state_dict for optimizers.

See get_state_dict for detailed usage.

Parameters
  • model (nn.Module) – the nn.Module of the model.

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • submodules (Optional[Set[Module]]) – only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns

The state_dict for optimizers.

Return type

Dict[str, Union[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]], List[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]]]]

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]

Load the model state_dict and optimizers state_dict.

The counterpart of get_state_dict to set the state_dict to the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements: 1) all FQNs are canonical FQNs as defined in get_state_dict, 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, 3) optimizer state_dict cannot contain the parameter IDs; the keys should be the canonical FQNs.

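Parameters
  • model (nn.Module) – the nn.Module of the model.

  • optimizers – The optimizers that are used to optimize model.

  • model_state_dict – the model state_dict to load.

  • optim_state_dict – the optimizer state_dict to load.

  • options (StateDictOptions) – the options to control how the model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.

Returns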

  • missing_keys is a list of str containing the missing keys of the model state_dict.

  • unexpected_keys is a list of str containing the unexpected keys of the model state_dict.

Return type

NamedTuple with missing_keys and unexpected_keys fields
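
To show what the two fields mean, here is a small pure-Python sketch. KeyReport and compare_keys are hypothetical names used only for illustration; the sketch compares two lists of keys and does not touch a real model.

```python
from typing import Iterable, List, NamedTuple


class KeyReport(NamedTuple):
    """Hypothetical stand-in for the returned NamedTuple."""
    missing_keys: List[str]
    unexpected_keys: List[str]


def compare_keys(expected: Iterable[str], provided: Iterable[str]) -> KeyReport:
    expected, provided = list(expected), list(provided)
    return KeyReport(
        # Keys the model expects but the given state_dict lacks.
        missing_keys=[k for k in expected if k not in provided],
        # Keys in the given state_dict that the model does not know.
        unexpected_keys=[k for k in provided if k not in expected],
    )


report = compare_keys(
    expected=["layer1.weight", "layer1.bias"],
    provided=["layer1.weight", "layer2.weight"],
)
```

Here report.missing_keys holds layer1.bias and report.unexpected_keys holds layer2.weight.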

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]

Load the model state_dict.

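The counterpart of get_model_state_dict to set the state_dict to the model. See set_state_dict for detailed usage.

Parameters
  • model (nn.Module) – the nn.Module of the model.

  • model_state_dict – the model state_dict to load.

  • options (StateDictOptions) – the options to control how the model state_dict should be loaded. See StateDictOptions for the details.

Returns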

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type

NamedTuple with missing_keys and unexpected_keys fields

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, *, optim_state_dict, options=None)[source]

Load the optimizers state_dict.

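The counterpart of get_optimizer_state_dict to set the state_dict to the optimizers. See set_state_dict for detailed usage.

Parameters
  • model (nn.Module) – the nn.Module of the model.

  • optimizers – The optimizers that are used to optimize model.

  • optim_state_dict – the optimizer state_dict to load.

  • options (StateDictOptions) – the options to control how the optimizer state_dict should be loaded. See StateDictOptions for the details.

Returns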

None

Return type

None

class torch.distributed.checkpoint.state_dict.StateDictOptions(fsdp_state_dict_type=StateDictType.SHARDED_STATE_DICT, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True)[source]

This dataclass specifies how get_state_dict/set_state_dict will work.

  • fsdp_state_dict_type: if the model contains FSDP sharded submodules, what FSDP state_dict type should be used. The default value is SHARDED_STATE_DICT.

  • ignore_frozen_params: if the value is True, the returned state_dict won’t contain any frozen parameters (those whose requires_grad is False). The default value is False.

  • keep_submodule_prefixes: when submodules is not None, this option indicates whether to keep the submodule prefixes in the state_dict keys. For example, suppose the submodule is module.pretrain and the full FQN of the parameter is pretrain.layer1.weight. When this option is True, the parameter’s key in the returned state_dict will be pretrain.layer1.weight. If the option is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, there may be conflicting FQNs, hence there should be only one submodule in submodules.

  • strict: the strict option when set_state_dict calls model.load_state_dict(). The default value is True.
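
The effect of keep_submodule_prefixes on keys, and why stripping prefixes can produce conflicts, can be sketched with plain dictionaries. select_keys is a hypothetical helper, not part of the library, and the key names are made up for the example.

```python
from typing import Dict, Iterable


def select_keys(
    state_dict: Dict[str, object],
    submodule_prefixes: Iterable[str],
    keep_submodule_prefixes: bool,
) -> Dict[str, object]:
    """Keep only entries under the given submodules, optionally stripping the prefix."""
    out: Dict[str, object] = {}
    for prefix in submodule_prefixes:
        dotted = prefix + "."
        for fqn, value in state_dict.items():
            if not fqn.startswith(dotted):
                continue
            key = fqn if keep_submodule_prefixes else fqn[len(dotted):]
            if key in out:
                raise KeyError(f"conflicting FQN after stripping prefixes: {key}")
            out[key] = value
    return out


full = {"pretrain.layer1.weight": 1, "head.layer1.weight": 2}

kept = select_keys(full, ["pretrain"], keep_submodule_prefixes=True)
stripped = select_keys(full, ["pretrain"], keep_submodule_prefixes=False)

# Two submodules with the prefix stripped map different parameters to one key.
try:
    select_keys(full, ["pretrain", "head"], keep_submodule_prefixes=False)
    conflict = False
except KeyError:
    conflict = True
```

The last call shows the conflict the text warns about: both parameters would be stored under layer1.weight.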
