edsnlp.models.pytorch_wrapper

PredT = typing.TypeVar('PredT') module-attribute

PytorchWrapperModule

Bases: torch.nn.Module

Source code in edsnlp/models/pytorch_wrapper.py
class PytorchWrapperModule(torch.nn.Module):
    def __init__(
        self,
        input_size: Optional[int] = None,
        n_labels: Optional[int] = None,
    ):
        """
        PyTorch wrapping module for spaCy.
        Models that expect to be wrapped with
        [wrap_pytorch_model][edsnlp.models.pytorch_wrapper.wrap_pytorch_model]
        should inherit from this module.

        Parameters
        ----------
        input_size: int
            Size of the input embeddings
        n_labels: int
            Number of labels predicted by the module
        """
        super().__init__()

        self.cfg = {"n_labels": n_labels, "input_size": input_size}

    @property
    def n_labels(self):
        return self.cfg["n_labels"]

    @property
    def input_size(self):
        return self.cfg["input_size"]

    def load_state_dict(
        self, state_dict: OrderedDict[str, torch.Tensor], strict: bool = True
    ):
        """
        Loads the model inplace from a dumped `state_dict` object

        Parameters
        ----------
        state_dict: OrderedDict[str, torch.Tensor]
        strict: bool
        """
        self.cfg = state_dict.pop("cfg")
        self.initialize()
        super().load_state_dict(state_dict, strict)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Dumps the model state, including its configuration, to a `state_dict` object

        Parameters
        ----------
        destination: Any
        prefix: str
        keep_vars: bool

        Returns
        -------
        dict
        """
        state = super().state_dict(destination, prefix, keep_vars)
        state["cfg"] = self.cfg
        return state

    def set_n_labels(self, n_labels):
        """
        Sets the number of labels. To instantiate the linear layer, we need to
        call the `initialize` method.

        Parameters
        ----------
        n_labels: int
            Number of different labels predicted by this module
        """
        self.cfg["n_labels"] = n_labels

    def initialize(self):
        """
        Once the number of labels n_labels is known, this method
        initializes the torch linear layer.
        """
        raise NotImplementedError()

    def forward(
        self,
        embeds: torch.FloatTensor,
        mask: torch.BoolTensor,
        *,
        additional_outputs: typing.Dict[str, Any] = None,
        is_train: bool = False,
        is_predict: bool = False,
    ) -> Optional[torch.FloatTensor]:
        """
        Apply the nested PyTorch module to compute the loss and/or predict the
        outputs (the two modes are not mutually exclusive).
        If outputs are predicted, they are stored in the `additional_outputs`
        dictionary.

        Parameters
        ----------
        embeds: torch.FloatTensor
            Input embeddings
        mask: torch.BoolTensor
            Input embeddings mask
        additional_outputs: Dict[str, Any]
            Additional outputs that should not / cannot be back-propagated through
            (Thinc treats PyTorch models solely as differentiable functions, but the
            CRF that we employ performs its best-path tag decoding in PyTorch).
            The predicted outputs are stored in this dictionary.
        is_train: bool=False
            Whether to compute the loss for training (defaults to False)
        is_predict: bool=False
            Whether to compute the predictions (defaults to False)

        Returns
        -------
        Optional[torch.FloatTensor]
            Optional loss of shape [1] used to train the model
        """
        raise NotImplementedError()
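
Below is a minimal, hedged sketch of a subclass honouring this contract. The name TokenClassifier, the single linear layer and the cross-entropy loss are illustrative assumptions, not part of edsnlp: sizes live in cfg, layers are created in initialize, and forward returns an optional loss while writing predictions into additional_outputs.

import torch
from edsnlp.models.pytorch_wrapper import PytorchWrapperModule


class TokenClassifier(PytorchWrapperModule):
    def __init__(self, input_size=None, n_labels=None):
        super().__init__(input_size=input_size, n_labels=n_labels)
        self.linear = None  # created in `initialize`, once the sizes are known

    def initialize(self):
        # sizes are read from `self.cfg`, which may have been updated after __init__
        self.linear = torch.nn.Linear(self.input_size, self.n_labels)

    def forward(
        self,
        embeds,
        mask,
        targets=None,
        *,
        additional_outputs=None,
        is_train=False,
        is_predict=False,
    ):
        logits = self.linear(embeds)
        loss = None
        if is_train:
            # loss of shape [1], as expected by the wrapper
            loss = torch.nn.functional.cross_entropy(
                logits[mask], targets[mask]
            ).reshape(1)
        if is_predict:
            # predictions are written into the shared dict, not returned
            additional_outputs["tags"] = logits.argmax(-1).masked_fill(~mask, -1)
        return loss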

cfg = {'n_labels': n_labels, 'input_size': input_size} instance-attribute

__init__(input_size=None, n_labels=None)

PyTorch wrapping module for spaCy. Models that expect to be wrapped with wrap_pytorch_model should inherit from this module.

PARAMETER DESCRIPTION
input_size

Size of the input embeddings

TYPE: Optional[int] DEFAULT: None

n_labels

Number of labels predicted by the module

TYPE: Optional[int] DEFAULT: None

Source code in edsnlp/models/pytorch_wrapper.py
def __init__(
    self,
    input_size: Optional[int] = None,
    n_labels: Optional[int] = None,
):
    """
    PyTorch wrapping module for spaCy.
    Models that expect to be wrapped with
    [wrap_pytorch_model][edsnlp.models.pytorch_wrapper.wrap_pytorch_model]
    should inherit from this module.

    Parameters
    ----------
    input_size: int
        Size of the input embeddings
    n_labels: int
        Number of labels predicted by the module
    """
    super().__init__()

    self.cfg = {"n_labels": n_labels, "input_size": input_size}

n_labels property

Source code in edsnlp/models/pytorch_wrapper.py
@property
def n_labels(self):
    return self.cfg["n_labels"]

input_size property

Source code in edsnlp/models/pytorch_wrapper.py
@property
def input_size(self):
    return self.cfg["input_size"]

load_state_dict(state_dict, strict=True)

Loads the model inplace from a dumped state_dict object

PARAMETER DESCRIPTION
state_dict

TYPE: OrderedDict[str, torch.Tensor]

strict

TYPE: bool DEFAULT: True

Source code in edsnlp/models/pytorch_wrapper.py
def load_state_dict(
    self, state_dict: OrderedDict[str, torch.Tensor], strict: bool = True
):
    """
    Loads the model inplace from a dumped `state_dict` object

    Parameters
    ----------
    state_dict: OrderedDict[str, torch.Tensor]
    strict: bool
    """
    self.cfg = state_dict.pop("cfg")
    self.initialize()
    super().load_state_dict(state_dict, strict)

state_dict(destination=None, prefix='', keep_vars=False)

Dumps the model state, including its configuration, to a state_dict object

PARAMETER DESCRIPTION
destination

DEFAULT: None

prefix

DEFAULT: ''

keep_vars

DEFAULT: False

RETURNS DESCRIPTION
dict
Source code in edsnlp/models/pytorch_wrapper.py
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """
    Dumps the model state, including its configuration, to a `state_dict` object

    Parameters
    ----------
    destination: Any
    prefix: str
    keep_vars: bool

    Returns
    -------
    dict
    """
    state = super().state_dict(destination, prefix, keep_vars)
    state["cfg"] = self.cfg
    return state
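
As a hedged usage sketch (reusing the hypothetical TokenClassifier subclass shown above), the dumped state embeds a "cfg" entry, so a freshly constructed module can re-initialize its layers before loading the weights:

module = TokenClassifier(input_size=96, n_labels=4)
module.initialize()

state = module.state_dict()        # regular weights plus a "cfg" entry
torch.save(state, "model.pt")

fresh = TokenClassifier()          # sizes unknown at construction time
fresh.load_state_dict(torch.load("model.pt"))
# load_state_dict pops "cfg", calls initialize(), then loads the weights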

set_n_labels(n_labels)

Sets the number of labels. To instantiate the linear layer, we need to call the initialize method.

PARAMETER DESCRIPTION
n_labels

Number of different labels predicted by this module

Source code in edsnlp/models/pytorch_wrapper.py
def set_n_labels(self, n_labels):
    """
    Sets the number of labels. To instantiate the linear layer, we need to
    call the `initialize` method.

    Parameters
    ----------
    n_labels: int
        Number of different labels predicted by this module
    """
    self.cfg["n_labels"] = n_labels

initialize()

Once the number of labels n_labels is known, this method initializes the torch linear layer.

Source code in edsnlp/models/pytorch_wrapper.py
def initialize(self):
    """
    Once the number of labels n_labels is known, this method
    initializes the torch linear layer.
    """
    raise NotImplementedError()

forward(embeds, mask, *, additional_outputs=None, is_train=False, is_predict=False)

Apply the nested PyTorch module to compute the loss and/or predict the outputs (the two modes are not mutually exclusive). If outputs are predicted, they are stored in the additional_outputs dictionary.

PARAMETER DESCRIPTION
embeds

Input embeddings

TYPE: torch.FloatTensor

mask

Input embeddings mask

TYPE: torch.BoolTensor

additional_outputs

Additional outputs that should not / cannot be back-propagated through (Thinc treats PyTorch models solely as differentiable functions, but the CRF that we employ performs its best-path tag decoding in PyTorch). The predicted outputs are stored in this dictionary.

TYPE: typing.Dict[str, Any] DEFAULT: None

is_train

Whether to compute the loss for training (defaults to False)

TYPE: bool DEFAULT: False

is_predict

Whether to compute the predictions (defaults to False)

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Optional[torch.FloatTensor]

Optional loss of shape [1] used to train the model

Source code in edsnlp/models/pytorch_wrapper.py
def forward(
    self,
    embeds: torch.FloatTensor,
    mask: torch.BoolTensor,
    *,
    additional_outputs: typing.Dict[str, Any] = None,
    is_train: bool = False,
    is_predict: bool = False,
) -> Optional[torch.FloatTensor]:
    """
    Apply the nested PyTorch module to compute the loss and/or predict the
    outputs (the two modes are not mutually exclusive).
    If outputs are predicted, they are stored in the `additional_outputs`
    dictionary.

    Parameters
    ----------
    embeds: torch.FloatTensor
        Input embeddings
    mask: torch.BoolTensor
        Input embeddings mask
    additional_outputs: Dict[str, Any]
        Additional outputs that should not / cannot be back-propagated through
        (Thinc treats PyTorch models solely as differentiable functions, but
        the CRF that we employ performs its best-path tag decoding in PyTorch).
        The predicted outputs are stored in this dictionary.
    is_train: bool=False
        Whether to compute the loss for training (defaults to False)
    is_predict: bool=False
        Whether to compute the predictions (defaults to False)

    Returns
    -------
    Optional[torch.FloatTensor]
        Optional loss of shape [1] used to train the model
    """
    raise NotImplementedError()
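
To make the calling convention concrete, here is a hedged sketch using the hypothetical TokenClassifier subclass from above; the extra positional input (here targets) depends on the concrete module:

module = TokenClassifier(input_size=96, n_labels=4)
module.initialize()

embeds = torch.rand(2, 7, 96)
mask = torch.ones(2, 7, dtype=torch.bool)
targets = torch.randint(0, 4, (2, 7))

# training call: returns a loss of shape [1]
loss = module(embeds, mask, targets, additional_outputs={}, is_train=True)

# prediction call: the loss may be None, predictions land in the dict
outputs = {}
module(embeds, mask, additional_outputs=outputs, is_predict=True)
outputs["tags"].shape  # torch.Size([2, 7])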

custom_xp2torch(model, X)

Source code in edsnlp/models/pytorch_wrapper.py
def custom_xp2torch(model, X):
    main = xp2torch(X[0], requires_grad=True)
    rest = convert_recursive(is_xp_array, lambda x: xp2torch(x), X[1:])

    def reverse_conversion(dXtorch):
        dX = torch2xp(dXtorch.args[0])
        return dX

    return (main, *rest), reverse_conversion
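
A small, hedged illustration of the conversion contract (the arrays and the None model argument, which the helper does not use, are made up for the example): the first array is converted with gradient tracking, the remaining ones without, and the returned callback maps a gradient wrapped in an ArgsKwargs back to the original array library:

import numpy as np
from thinc.api import ArgsKwargs

embeds = np.random.rand(2, 5, 8).astype("float32")  # differentiable input
spans = np.zeros((3, 4), dtype="int64")             # extra, non-differentiable input

(torch_embeds, torch_spans), get_d_embeds = custom_xp2torch(None, (embeds, spans))
assert torch_embeds.requires_grad and not torch_spans.requires_grad

# the backward pass hands the gradient back wrapped in an ArgsKwargs
d_embeds = get_d_embeds(ArgsKwargs(args=(torch_embeds,), kwargs={}))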

pytorch_forward(model, X, is_train=False)

Run the stacked CRF PyTorch model to train or apply a nested NER model

PARAMETER DESCRIPTION
model

TYPE: Model

X

TYPE: Tuple[Iterable[Doc], PredT, bool]

is_train

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Tuple[Tuple[Floats1d, PredT], Callable[[Floats1d], Any]]
Source code in edsnlp/models/pytorch_wrapper.py
def pytorch_forward(
    model: Model,
    X: Tuple[Iterable[Doc], PredT, bool],
    is_train: bool = False,
) -> Tuple[Tuple[Floats1d, PredT], Callable[[Floats1d], Any]]:
    """
    Run the stacked CRF PyTorch model to train or apply a nested NER model

    Parameters
    ----------
    model: Model
    X: Tuple[Iterable[Doc], PredT, bool]
    is_train: bool

    Returns
    -------
    Tuple[Tuple[Floats1d, PredT], Callable[[Floats1d], Any]]
    """
    [docs, *rest_X, is_predict] = X
    encoder: Model[List[Doc], List[Floats2d]] = model.get_ref("encoder")
    embeds_list, bp_embeds = encoder(docs, is_train=is_train)
    embeds = model.ops.pad(embeds_list)  # pad embeds

    ##################################################
    # Prepare the torch nested ner crf module inputs #
    ##################################################
    additional_outputs = {}
    # Convert input from numpy/cupy to torch
    (torch_embeds, *torch_rest), get_d_embeds = custom_xp2torch(
        model, (embeds, *rest_X)
    )
    # Prepare token mask from docs' lengths
    torch_mask = (
        torch.arange(embeds.shape[1], device=torch_embeds.device)
        < torch.tensor([d.shape[0] for d in embeds_list], device=torch_embeds.device)[
            :, None
        ]
    )

    #################
    # Run the model #
    #################
    loss_torch, torch_backprop = model.shims[0](
        ArgsKwargs(
            (torch_embeds, torch_mask, *torch_rest),
            {
                "additional_outputs": additional_outputs,
                "is_train": is_train,
                "is_predict": is_predict,
            },
        ),
        is_train,
    )

    ####################################
    # Postprocess the module's outputs #
    ####################################
    loss = torch2xp(loss_torch) if loss_torch is not None else None
    additional_outputs = convert_recursive(is_torch_array, torch2xp, additional_outputs)

    def backprop(d_loss: Floats1d) -> Any:
        d_loss_torch = ArgsKwargs(
            args=((loss_torch,),), kwargs={"grad_tensors": xp2torch(d_loss)}
        )
        d_embeds_torch = torch_backprop(d_loss_torch)
        d_embeds = get_d_embeds(d_embeds_torch)
        d_embeds_list = [
            d_padded_row[: len(d_item)]
            for d_item, d_padded_row in zip(embeds_list, d_embeds)
        ]
        d_docs = bp_embeds(d_embeds_list)
        return d_docs

    return (loss, additional_outputs), backprop
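
The padding mask above is built by comparing position indices with per-document lengths; a tiny standalone illustration with made-up lengths:

import torch

lengths = torch.tensor([5, 3, 2])          # number of tokens in each padded doc
mask = torch.arange(5) < lengths[:, None]
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False],
#         [ True,  True, False, False, False]])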

instance_init(model, X=None, Y=None)

Initializes the model by setting the input size of the model layers and the number of predicted labels

PARAMETER DESCRIPTION
model

Nested NER thinc model

TYPE: Model

X

list of documents on which we apply the encoder layer

TYPE: List[Doc] DEFAULT: None

Y

Unused gold spans

TYPE: Ints2d DEFAULT: None

Source code in edsnlp/models/pytorch_wrapper.py
def instance_init(model: Model, X: List[Doc] = None, Y: Ints2d = None) -> Model:
    """
    Initializes the model by setting the input size of the model layers and the number
    of predicted labels

    Parameters
    ----------
    model: Model
        Nested NER thinc model
    X: List[Doc]
        list of documents on which we apply the encoder layer
    Y: Ints2d
        Unused gold spans

    Returns
    -------
    Model
        The initialized model
    """
    encoder = model.get_ref("encoder")
    if X is not None:
        encoder.initialize(X)

    pt_model = model.attrs["pt_model"]
    pt_model.cfg["input_size"] = encoder.get_dim("nO")
    pt_model.initialize()
    pt_model.to(get_torch_default_device())
    model.set_dim("nI", pt_model.input_size)

    return model

wrap_pytorch_model(encoder, pt_model)

Chains and wraps a spaCy/Thinc encoder model (such as a tok2vec) and a PyTorch model. The loss should be computed directly in the PyTorch module; categorical predictions are supported.

PARAMETER DESCRIPTION
encoder

The Thinc document token embedding layer

TYPE: Model[List[Doc], List[Floats2d]]

pt_model

The Pytorch model

TYPE: PytorchWrapperModule

RETURNS DESCRIPTION
Model[Tuple[Iterable[Doc], Optional[PredT], Optional[bool]], Tuple[Floats1d, PredT]]

The wrapped Thinc model, taking inputs (docs, gold, *rest, is_predict) and producing outputs (loss, *additional_outputs)

Source code in edsnlp/models/pytorch_wrapper.py
def wrap_pytorch_model(
    encoder: Model[List[Doc], List[Floats2d]],
    pt_model: PytorchWrapperModule,
) -> Model[
    Tuple[Iterable[Doc], Optional[PredT], Optional[bool]],
    Tuple[Floats1d, PredT],
]:
    """
    Chains and wraps a spaCy/Thinc encoder model (such as a tok2vec) and a PyTorch
    model. The loss should be computed directly in the PyTorch module; categorical
    predictions are supported.

    Parameters
    ----------
    encoder: Model[List[Doc], List[Floats2d]]
        The Thinc document token embedding layer
    pt_model: PytorchWrapperModule
        The Pytorch model

    Returns
    -------
    Model[
        Tuple[Iterable[Doc], Optional[PredT], Optional[bool]],
        # inputs (docs, gold, *rest, is_predict)
        Tuple[Floats1d, PredT],
        # outputs (loss, *additional_outputs)
    ]
    """
    return Model(
        "pytorch",
        pytorch_forward,
        attrs={
            "set_n_labels": pt_model.set_n_labels,
            "pt_model": pt_model,
        },
        layers=[encoder],
        shims=[PyTorchShim(pt_model)],
        refs={"encoder": encoder},
        dims={"nI": None, "nO": None},
        init=instance_init,
    )
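
As a closing, hypothetical wiring sketch (not taken from edsnlp): assuming encoder is any Thinc model mapping List[Doc] to List[Floats2d] (for instance a spaCy tok2vec layer), docs is a list of Doc objects, and TokenClassifier is the illustrative subclass shown earlier:

pt_model = TokenClassifier(n_labels=4)       # input_size is filled in at init time
model = wrap_pytorch_model(encoder=encoder, pt_model=pt_model)

# instance_init: runs the encoder, sets input_size, calls pt_model.initialize()
model.initialize(X=docs)

# prediction: the input is (docs, *extra_inputs, is_predict)
(loss, extra_outputs), backprop = model((docs, True), is_train=False)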