"""Source code for ``clinicadl.utils.seed``.

Utilities to seed pseudo-random number generators for reproducibility.
"""

import logging
import os
import random
from contextlib import contextmanager
from typing import Generator, Optional

import numpy as np
import torch

from .computational.ddp import get_rank

logger = logging.getLogger(__name__)

# Accepted seed range: every value representable as an unsigned 32-bit integer.
MIN_SEED_VALUE = np.iinfo(np.uint32).min
MAX_SEED_VALUE = np.iinfo(np.uint32).max

# Names of the environment variables written by the seeding helpers of this module.
GLOBAL_SEED = "CLINICADL_GLOBAL_SEED"  # the seed used by ``seed_everything``
DETERMINISTIC = "CLINICADL_DETERMINISTIC"  # set to "true" when deterministic mode is on
PYTHON_HASH_SEED = "PYTHONHASHSEED"
CUBLAS_CONFIG = "CUBLAS_WORKSPACE_CONFIG"


def pl_worker_init_function(worker_id: int) -> None:
    """
    To handle seeding with multiprocessing.

    Designed to be used as the ``worker_init_fn`` of a data loader. It derives distinct
    seeds for NumPy, PyTorch and Python's ``random`` module in each worker. The seeds come
    from the worker's PyTorch seed, the worker id and the distributed rank
    (``get_rank()``).

    From https://pytorch-lightning.readthedocs.io/en/1.7.7/_modules/pytorch_lightning/utilities/seed.html#pl_worker_init_function.

    Parameters
    ----------
    worker_id : int
        Index of the current worker.
    """
    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    process_seed = torch.initial_seed()
    # back out the base seed so we can use all the bits (https://docs.pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
    base_seed = process_seed - worker_id
    ss = np.random.SeedSequence([base_seed, worker_id, get_rank()])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    # PyTorch >= 1.7 takes a 64-bit seed. This module already relies on
    # ``torch.use_deterministic_algorithms`` (torch >= 1.8), so no version check is needed.
    # A string comparison on ``torch.__version__`` would be lexicographic and unreliable.
    torch.manual_seed(int(torch_ss.generate_state(1, dtype=np.uint64)[0]))
    # use 128 bits expressed as an integer: two 64-bit words, high word first
    high_word, low_word = (int(word) for word in stdlib_ss.generate_state(2, dtype=np.uint64))
    random.seed((high_word << 64) | low_word)


def seed_everything(seed: Optional[int] = None, deterministic: bool = False) -> None:
    """
    To control reproducibility.

    It will seed pseudo-random number generators in: PyTorch, Numpy and Python's random
    module. The seed can be accessed via the environment variable ``"CLINICADL_GLOBAL_SEED"``.

    Besides, if ``deterministic=True``, PyTorch's operations will be configured in
    deterministic mode, to the extent possible. In this case, an environment variable
    ``"CLINICADL_DETERMINISTIC"`` will also be created.

    .. warning::
        ``deterministic=True``

        - does not guarantee fully reproducible results; it only ensures determinism
          within PyTorch’s current limitations;
        - comes with a cost in computing performances.

        It is advised to use this parameter only for your final experiments.

    Parameters
    ----------
    seed : Optional[int], default=None
        The seed to use. If ``None``, a random seed will be generated.
    deterministic : bool, default=False
        Whether to configure PyTorch's operations in deterministic mode.

    Raises
    ------
    ValueError
        If ``seed`` is not in ``[MIN_SEED_VALUE, MAX_SEED_VALUE]`` (the unsigned 32-bit range).

    Examples
    --------
    .. code-block::

        from clinicadl.utils.seed import seed_everything
        import torch
        import numpy as np
        import random

    .. code-block::

        >>> seed_everything(0)
        >>> torch.randn(1), np.random.randn(), random.randint(0, 100)
        (tensor([1.5410]), 1.764052345967664, 49)
        >>> seed_everything(0)
        >>> torch.randn(1), np.random.randn(), random.randint(0, 100)
        (tensor([1.5410]), 1.764052345967664, 49)
    """
    seed = random.randint(MIN_SEED_VALUE, MAX_SEED_VALUE) if seed is None else seed
    if not MIN_SEED_VALUE <= seed <= MAX_SEED_VALUE:
        raise ValueError(
            f"Seed must be between {MIN_SEED_VALUE} and {MAX_SEED_VALUE}. Got {seed}"
        )

    # Seed, in order: stdlib, NumPy, PyTorch CPU, then every CUDA device.
    # Seeding all CUDA devices is acceptable since there is one process per GPU.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # NOTE(review): PYTHONHASHSEED is read by the interpreter at startup, so setting it here
    # presumably only affects child processes launched afterwards — confirm intent.
    os.environ.update({PYTHON_HASH_SEED: str(seed), GLOBAL_SEED: str(seed)})
    logger.info("Global seed set to %d", seed)

    if not deterministic:
        return

    # NOTE(review): cuBLAS presumably reads this variable when it is initialized, so it may
    # have no effect if CUDA work already happened in this process — confirm.
    os.environ[CUBLAS_CONFIG] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ[DETERMINISTIC] = "true"
@contextmanager
def seed_everything_context(
    seed: Optional[int] = None, deterministic: bool = False
) -> Generator[None, None, None]:
    """
    Context manager to control reproducibility.

    Does the same as :py:func:`seed_everything`. When exiting the context, it restores the
    previous state of:

    - the random generators (Python, NumPy, PyTorch CPU and CUDA);
    - PyTorch's deterministic settings, including the ``warn_only`` flag when the installed
      PyTorch exposes it;
    - the cuDNN ``deterministic``/``benchmark`` flags;
    - the environment variables written by :py:func:`seed_everything`.

    Parameters
    ----------
    seed : Optional[int], default=None
        The seed to use. If ``None``, a random seed will be generated.
    deterministic : bool, default=False
        Whether to configure PyTorch's operations in deterministic mode.

    Examples
    --------
    .. code-block::

        from clinicadl.utils.seed import seed_everything_context
        import torch
        import numpy as np
        import random

    .. code-block::

        >>> with seed_everything_context(0):
        ...     print(torch.randn(1), np.random.randn(), random.randint(0, 100))
        tensor([1.5410]) 1.764052345967664 49
        >>> print(torch.randn(1), np.random.randn(), random.randint(0, 100))
        tensor([0.8120]) 0.5023488957207493 50
        >>> with seed_everything_context(0):
        ...     print(torch.randn(1), np.random.randn(), random.randint(0, 100))
        tensor([1.5410]) 1.764052345967664 49
    """
    # ``None`` marks a variable that was unset and must be removed again on exit.
    prev_env = {
        key: os.environ.get(key)
        for key in (PYTHON_HASH_SEED, GLOBAL_SEED, DETERMINISTIC, CUBLAS_CONFIG)
    }
    py_random_state = random.getstate()
    np_state = np.random.get_state()
    torch_cpu_state = torch.get_rng_state()
    torch_cuda_states = torch.cuda.get_rng_state_all()
    prev_det_algos = torch.are_deterministic_algorithms_enabled()
    # Without this, restoring would reset ``warn_only`` to False. The getter is looked up
    # dynamically because older PyTorch versions do not provide it.
    get_warn_only = getattr(torch, "is_deterministic_algorithms_warn_only_enabled", None)
    prev_det_warn_only = get_warn_only() if get_warn_only is not None else None
    prev_cudnn_deterministic = torch.backends.cudnn.deterministic
    prev_cudnn_benchmark = torch.backends.cudnn.benchmark

    try:
        seed_everything(seed, deterministic)
        yield
    finally:
        # Restore the environment first, so a failure in a torch call below cannot skip it.
        for key, value in prev_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        random.setstate(py_random_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_cpu_state)
        torch.cuda.set_rng_state_all(torch_cuda_states)
        if prev_det_warn_only is None:
            torch.use_deterministic_algorithms(prev_det_algos)
        else:
            torch.use_deterministic_algorithms(prev_det_algos, warn_only=prev_det_warn_only)
        torch.backends.cudnn.deterministic = prev_cudnn_deterministic
        torch.backends.cudnn.benchmark = prev_cudnn_benchmark