Skip to content

confit.utils.random

set_seed

Source code in confit/utils/random.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class set_seed:
    def __init__(self, seed, cuda: Optional[bool] = None):
        """
        Set seed values for random generators.
        If used as a context, restore the random state
        used before entering the context.

        The `random` module is always seeded; `torch` and `numpy` are
        seeded as well when they can be imported.

        Parameters
        ----------
        seed: Optional[Union[int, bool]]
            Value used as a seed. If `True`, a random seed in
            ``[1, 2**16]`` is drawn with `random.randint`. If `None`, no
            generator is seeded (the current state is still saved).
            The seed actually used is stored in `self.seed`.
        cuda: Optional[bool]
            Saves (and seeds) the cuda random states too. If `None`,
            this is done only when CUDA is available.
        """
        # The seed is drawn *before* the state is saved: the restored state
        # therefore already accounts for this draw, so consecutive
        # `set_seed(True)` contexts do not replay the same draw.
        if seed is True:
            seed = random.randint(1, 2**16)
        # Keep the effective seed so a randomly drawn one can be reproduced.
        self.seed = seed
        self.state = get_random_generator_state(cuda)
        if seed is None:
            return

        random.seed(seed)
        try:
            import torch
        except ImportError:  # pragma: no cover
            pass
        else:
            torch.manual_seed(seed)
            if cuda or (
                cuda is None and torch.cuda.is_available()
            ):  # pragma: no cover
                # `manual_seed_all` seeds every GPU, the current one included.
                torch.cuda.manual_seed_all(seed)
        try:
            import numpy
        except ImportError:  # pragma: no cover
            pass
        else:
            numpy.random.seed(seed)

    def __enter__(self):
        """Return this object; the state to restore was saved in `__init__`."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore the random generator state saved at construction.

        Returns `None`, so exceptions raised inside the context propagate.
        """
        set_random_generator_state(self.state)

__init__(seed, cuda=None)

Set seed values for the random, numpy and torch random generators. When used as a context manager, restore the random state that was active before entering the context.

PARAMETER DESCRIPTION
seed

Value used as a seed. If True, a random seed is drawn; if None, the generators are not seeded.

cuda

Whether to also save and seed the CUDA random states. If None (the default), this is done only when CUDA is available.

TYPE: Optional[bool] DEFAULT: None

Source code in confit/utils/random.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def __init__(self, seed, cuda: Optional[bool] = None):
    """
    Set seed values for random generators.
    If used as a context, restore the random state
    used before entering the context.

    The `random` module is always seeded; `torch` and `numpy` are
    seeded as well when they can be imported.

    Parameters
    ----------
    seed: Optional[Union[int, bool]]
        Value used as a seed. If `True`, a random seed in
        ``[1, 2**16]`` is drawn with `random.randint`. If `None`, no
        generator is seeded (the current state is still saved).
        The seed actually used is stored in `self.seed`.
    cuda: Optional[bool]
        Saves (and seeds) the cuda random states too. If `None`,
        this is done only when CUDA is available.
    """
    # The seed is drawn *before* the state is saved: the restored state
    # therefore already accounts for this draw, so consecutive
    # `set_seed(True)` contexts do not replay the same draw.
    if seed is True:
        seed = random.randint(1, 2**16)
    # Keep the effective seed so a randomly drawn one can be reproduced.
    self.seed = seed
    self.state = get_random_generator_state(cuda)
    if seed is None:
        return

    random.seed(seed)
    try:
        import torch
    except ImportError:  # pragma: no cover
        pass
    else:
        torch.manual_seed(seed)
        if cuda or (
            cuda is None and torch.cuda.is_available()
        ):  # pragma: no cover
            # `manual_seed_all` seeds every GPU, the current one included.
            torch.cuda.manual_seed_all(seed)
    try:
        import numpy
    except ImportError:  # pragma: no cover
        pass
    else:
        numpy.random.seed(seed)

get_random_generator_state(cuda=None)

Get the state of the random number generators of the torch, numpy and random modules.

PARAMETER DESCRIPTION
cuda

Whether to also save the CUDA random states. If None (the default), they are saved only when CUDA is available.

DEFAULT: None

RETURNS DESCRIPTION
RandomGeneratorState
Source code in confit/utils/random.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_random_generator_state(cuda=None):
    """
    Get the `torch`, `numpy` and `random` random generator state.

    Parameters
    ----------
    cuda: Optional[bool]
        Saves the cuda random states too. If `None`, the CUDA states are
        saved only when CUDA is available.

    Returns
    -------
    RandomGeneratorState
        A field is `None` when the matching library cannot be imported
        (or, for the CUDA states, when they were not requested).
    """
    py_rng = random.getstate()
    np_rng = None
    torch_cpu_rng = None
    torch_gpu_rngs = None

    try:
        import torch

        torch_cpu_rng = torch.random.get_rng_state()
        # `cuda=None` means "auto": only query the GPUs if CUDA is usable.
        want_cuda = cuda or (cuda is None and torch.cuda.is_available())
        if want_cuda:  # pragma: no cover
            torch_gpu_rngs = torch.cuda.get_rng_state_all()
    except ImportError:  # pragma: no cover
        pass

    try:
        import numpy

        np_rng = numpy.random.get_state()
    except ImportError:  # pragma: no cover
        pass

    return RandomGeneratorState(py_rng, torch_cpu_rng, np_rng, torch_gpu_rngs)

set_random_generator_state(state)

Restore the state of the random number generators of the torch, numpy and random modules.

PARAMETER DESCRIPTION
state

The RandomGeneratorState to restore, as returned by get_random_generator_state.

Source code in confit/utils/random.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def set_random_generator_state(state):
    """
    Set the `torch`, `numpy` and `random` random generator state.

    Parameters
    ----------
    state: RandomGeneratorState
        State to restore. Fields that are `None` are skipped. CUDA states
        are restored only when CUDA is available and the number of saved
        states matches the number of visible devices.
    """
    random.setstate(state.random)

    saved_torch = state.torch
    if saved_torch is not None:
        import torch

        torch.random.set_rng_state(saved_torch)
        saved_cuda = state.torch_cuda
        cuda_restorable = (
            saved_cuda is not None
            and torch.cuda.is_available()
            and len(saved_cuda) == torch.cuda.device_count()
        )
        if cuda_restorable:  # pragma: no cover
            torch.cuda.set_rng_state_all(saved_cuda)

    saved_numpy = state.numpy
    if saved_numpy is not None:
        import numpy

        numpy.random.set_state(saved_numpy)