
edsnlp.models.stack_crf_ner

CRFMode

Bases: str, Enum

Source code in edsnlp/models/stack_crf_ner.py
class CRFMode(str, Enum):
    independent = "independent"
    joint = "joint"
    marginal = "marginal"

independent = 'independent' class-attribute

joint = 'joint' class-attribute

marginal = 'marginal' class-attribute
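
The three members select how the training loss is computed in StackedCRFNERModule.forward below: independent scores each token's tags with a plain softmax, joint uses the full CRF sequence likelihood, and marginal trains on the CRF's marginal tag distributions. Because CRFMode subclasses str, its members serialize and compare like plain strings, as in this minimal sketch:

from edsnlp.models.stack_crf_ner import CRFMode

mode = CRFMode.joint
assert mode == "joint"  # str subclass: members compare equal to their value
assert CRFMode("marginal") is CRFMode.marginal  # members can be looked up by value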

StackedCRFNERModule

Bases: PytorchWrapperModule

Source code in edsnlp/models/stack_crf_ner.py
class StackedCRFNERModule(PytorchWrapperModule):
    def __init__(
        self,
        input_size: Optional[int] = None,
        n_labels: Optional[int] = None,
        mode: CRFMode = CRFMode.joint,
    ):
        """
        Nested NER CRF module

        Parameters
        ----------
        input_size: int
            Size of the input embeddings
        n_labels: int
            Number of labels predicted by the module
        mode: CRFMode
            Loss mode of the CRF
        """
        super().__init__(input_size, n_labels)

        self.cfg["mode"] = mode

        assert mode in (CRFMode.independent, CRFMode.joint, CRFMode.marginal)
        self.crf = MultiLabelBIOULDecoder(1, learnable_transitions=False)

        self.classifier = None

    def initialize(self):
        """
        Once the number of labels n_labels is known, this method
        initializes the torch linear layer.
        """
        num_tags = self.n_labels * self.crf.num_tags
        self.classifier = torch.nn.Linear(self.input_size, num_tags)

    def forward(
        self,
        embeds: torch.FloatTensor,
        mask: torch.BoolTensor,
        spans: Optional[torch.LongTensor] = None,
        additional_outputs: Optional[Dict[str, Any]] = None,
        is_train: bool = False,
        is_predict: bool = False,
    ) -> Optional[torch.FloatTensor]:
        """
        Apply the nested NER module to the document embeddings to:
        - compute the loss
        - predict the spans
        These two operations are not mutually exclusive.
        If spans are predicted, they are assigned to the `additional_outputs`
        dictionary.

        Parameters
        ----------
        embeds: torch.FloatTensor
            Token embeddings to predict the tags from
        mask: torch.BoolTensor
            Mask of the sequences
        spans: Optional[torch.LongTensor]
            2d tensor of shape (n_spans, 4), each row being (doc_idx, label_idx, begin, end)
        additional_outputs: Dict[str, Any]
            Additional outputs that should not / cannot be back-propped through
            (Thinc treats PyTorch models solely as differentiable functions, but the
            CRF we employ decodes the best tag sequence with PyTorch).
            This dict will contain the predicted 2d tensor of spans
        is_train: bool=False
            Whether to compute the loss (defaults to False)
        is_predict: bool=False
            Whether to predict spans (defaults to False)

        Returns
        -------
        Optional[torch.FloatTensor]
            Optional 1d loss tensor (shape = [1]) used to train the model
        """
        n_samples, n_tokens = embeds.shape[:2]
        logits = self.classifier(embeds)
        crf_logits = flatten_dim(
            logits.view(n_samples, n_tokens, self.n_labels, self.crf.num_tags).permute(
                0, 2, 1, 3
            ),
            dim=0,
        )
        crf_mask = repeat(mask, self.n_labels, 0)
        loss = None
        if is_train:
            tags = self.crf.spans_to_tags(
                spans, n_samples=n_samples, n_tokens=n_tokens, n_labels=self.n_labels
            )
            crf_target = flatten_dim(
                torch.nn.functional.one_hot(tags, 5).bool()
                if len(tags.shape) == 3
                else tags,
                dim=0,
            )
            if self.cfg["mode"] == CRFMode.joint:
                loss = self.crf(
                    crf_logits,
                    crf_mask,
                    crf_target,
                )
            elif self.cfg["mode"] == CRFMode.independent:
                loss = (
                    -crf_logits.log_softmax(-1)
                    .masked_fill(~crf_target, IMPOSSIBLE)
                    .logsumexp(-1)[crf_mask]
                    .sum()
                )
            elif self.cfg["mode"] == CRFMode.marginal:
                crf_logits = self.crf.marginal(
                    crf_logits,
                    crf_mask,
                )
                loss = (
                    -crf_logits.log_softmax(-1)
                    .masked_fill(~crf_target, IMPOSSIBLE)
                    .logsumexp(-1)[crf_mask]
                    .sum()
                )
            if (loss > -IMPOSSIBLE).any():
                logger.warning(
                    "You likely have an impossible transition in your "
                    "training data NER tags, skipping this batch."
                )
                loss = torch.zeros(1, dtype=torch.float, device=embeds.device)
            loss = loss.sum().unsqueeze(0) / 100.0
        if is_predict:
            pred_tags = self.crf.decode(crf_logits, crf_mask).reshape(
                n_samples, self.n_labels, n_tokens
            )
            pred_spans = self.crf.tags_to_spans(pred_tags)
            additional_outputs["spans"] = pred_spans
        return loss

crf = MultiLabelBIOULDecoder(1, learnable_transitions=False) instance-attribute

classifier = None instance-attribute

__init__(input_size=None, n_labels=None, mode=CRFMode.joint)

Nested NER CRF module

PARAMETER DESCRIPTION
input_size

Size of the input embeddings

TYPE: Optional[int] DEFAULT: None

n_labels

Number of labels predicted by the module

TYPE: Optional[int] DEFAULT: None

mode

Loss mode of the CRF

TYPE: CRFMode DEFAULT: CRFMode.joint

Source code in edsnlp/models/stack_crf_ner.py
def __init__(
    self,
    input_size: Optional[int] = None,
    n_labels: Optional[int] = None,
    mode: CRFMode = CRFMode.joint,
):
    """
    Nested NER CRF module

    Parameters
    ----------
    input_size: int
        Size of the input embeddings
    n_labels: int
        Number of labels predicted by the module
    mode: CRFMode
        Loss mode of the CRF
    """
    super().__init__(input_size, n_labels)

    self.cfg["mode"] = mode

    assert mode in (CRFMode.independent, CRFMode.joint, CRFMode.marginal)
    self.crf = MultiLabelBIOULDecoder(1, learnable_transitions=False)

    self.classifier = None
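
A minimal construction sketch with assumed sizes (768 and 4 are illustrative; as initialize below suggests, both sizes may also be left as None and provided later):

from edsnlp.models.stack_crf_ner import CRFMode, StackedCRFNERModule

module = StackedCRFNERModule(
    input_size=768,      # assumed embedding width, e.g. a transformer hidden size
    n_labels=4,          # assumed number of entity labels
    mode=CRFMode.joint,  # train with the full CRF likelihood
)
module.initialize()      # builds the linear classifier once both sizes are known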

initialize()

Once the number of labels n_labels is known, this method initializes the torch linear layer.

Source code in edsnlp/models/stack_crf_ner.py
def initialize(self):
    """
    Once the number of labels n_labels is known, this method
    initializes the torch linear layer.
    """
    num_tags = self.n_labels * self.crf.num_tags
    self.classifier = torch.nn.Linear(self.input_size, num_tags)
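
The layer emits one logit per (label, tag) pair. Assuming the BIOUL tag set (5 tags per label, which the one_hot(tags, 5) call in forward below suggests), a worked size example:

# Assumed sizes: 4 labels x 5 BIOUL tags = 20 output units per token
input_size, n_labels, crf_num_tags = 768, 4, 5
num_tags = n_labels * crf_num_tags  # 20
# initialize() would then build torch.nn.Linear(768, 20)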

forward(embeds, mask, spans=None, additional_outputs=None, is_train=False, is_predict=False)

Apply the nested NER module to the document embeddings to:

- compute the loss
- predict the spans

These two operations are not mutually exclusive. If spans are predicted, they are assigned to the additional_outputs dictionary.

PARAMETER DESCRIPTION
embeds

Token embeddings to predict the tags from

TYPE: torch.FloatTensor

mask

Mask of the sequences

TYPE: torch.BoolTensor

spans

2d tensor of shape (n_spans, 4), each row being (doc_idx, label_idx, begin, end)

TYPE: Optional[torch.LongTensor] DEFAULT: None

additional_outputs

Additional outputs that should not / cannot be back-propped through (Thinc treats PyTorch models solely as differentiable functions, but the CRF we employ decodes the best tag sequence with PyTorch). This dict will contain the predicted 2d tensor of spans.

TYPE: Optional[Dict[str, Any]] DEFAULT: None

is_train

Whether to compute the loss (defaults to False)

TYPE: bool DEFAULT: False

is_predict

Whether to predict spans (defaults to False)

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Optional[torch.FloatTensor]

Optional 1d loss tensor (shape = [1]) used to train the model

Source code in edsnlp/models/stack_crf_ner.py
def forward(
    self,
    embeds: torch.FloatTensor,
    mask: torch.BoolTensor,
    spans: Optional[torch.LongTensor] = None,
    additional_outputs: Optional[Dict[str, Any]] = None,
    is_train: bool = False,
    is_predict: bool = False,
) -> Optional[torch.FloatTensor]:
    """
    Apply the nested NER module to the document embeddings to:
    - compute the loss
    - predict the spans
    These two operations are not mutually exclusive.
    If spans are predicted, they are assigned to the `additional_outputs`
    dictionary.

    Parameters
    ----------
    embeds: torch.FloatTensor
        Token embeddings to predict the tags from
    mask: torch.BoolTensor
        Mask of the sequences
    spans: Optional[torch.LongTensor]
        2d tensor of shape (n_spans, 4), each row being (doc_idx, label_idx, begin, end)
    additional_outputs: Dict[str, Any]
        Additional outputs that should not / cannot be back-propped through
        (Thinc treats PyTorch models solely as differentiable functions, but the
        CRF we employ decodes the best tag sequence with PyTorch).
        This dict will contain the predicted 2d tensor of spans
    is_train: bool=False
        Whether to compute the loss (defaults to False)
    is_predict: bool=False
        Whether to predict spans (defaults to False)

    Returns
    -------
    Optional[torch.FloatTensor]
        Optional 1d loss tensor (shape = [1]) used to train the model
    """
    n_samples, n_tokens = embeds.shape[:2]
    logits = self.classifier(embeds)
    crf_logits = flatten_dim(
        logits.view(n_samples, n_tokens, self.n_labels, self.crf.num_tags).permute(
            0, 2, 1, 3
        ),
        dim=0,
    )
    crf_mask = repeat(mask, self.n_labels, 0)
    loss = None
    if is_train:
        tags = self.crf.spans_to_tags(
            spans, n_samples=n_samples, n_tokens=n_tokens, n_labels=self.n_labels
        )
        crf_target = flatten_dim(
            torch.nn.functional.one_hot(tags, 5).bool()
            if len(tags.shape) == 3
            else tags,
            dim=0,
        )
        if self.cfg["mode"] == CRFMode.joint:
            loss = self.crf(
                crf_logits,
                crf_mask,
                crf_target,
            )
        elif self.cfg["mode"] == CRFMode.independent:
            loss = (
                -crf_logits.log_softmax(-1)
                .masked_fill(~crf_target, IMPOSSIBLE)
                .logsumexp(-1)[crf_mask]
                .sum()
            )
        elif self.cfg["mode"] == CRFMode.marginal:
            crf_logits = self.crf.marginal(
                crf_logits,
                crf_mask,
            )
            loss = (
                -crf_logits.log_softmax(-1)
                .masked_fill(~crf_target, IMPOSSIBLE)
                .logsumexp(-1)[crf_mask]
                .sum()
            )
        if (loss > -IMPOSSIBLE).any():
            logger.warning(
                "You likely have an impossible transition in your "
                "training data NER tags, skipping this batch."
            )
            loss = torch.zeros(1, dtype=torch.float, device=embeds.device)
        loss = loss.sum().unsqueeze(0) / 100.0
    if is_predict:
        pred_tags = self.crf.decode(crf_logits, crf_mask).reshape(
            n_samples, self.n_labels, n_tokens
        )
        pred_spans = self.crf.tags_to_spans(pred_tags)
        additional_outputs["spans"] = pred_spans
    return loss
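
A minimal call sketch with random inputs; every size here is assumed, and spans follows the (doc_idx, label_idx, begin, end) row layout documented above:

import torch

from edsnlp.models.stack_crf_ner import CRFMode, StackedCRFNERModule

module = StackedCRFNERModule(input_size=768, n_labels=4, mode=CRFMode.joint)
module.initialize()

embeds = torch.randn(2, 10, 768)            # 2 docs, 10 tokens each
mask = torch.ones(2, 10, dtype=torch.bool)  # no padding in this sketch
spans = torch.tensor([[0, 1, 2, 5]])        # doc 0, label 1, tokens 2 to 5

outputs = {}
loss = module(
    embeds,
    mask,
    spans,
    additional_outputs=outputs,
    is_train=True,
    is_predict=True,
)
# loss: 1d tensor of shape [1]; outputs["spans"]: the predicted spans tensor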

repeat(t, n, dim, interleave=True)

Source code in edsnlp/models/stack_crf_ner.py
def repeat(t, n, dim, interleave=True):
    repeat_dim = dim + (1 if interleave else 0)
    return (
        t.unsqueeze(repeat_dim)
        .repeat_interleave(n, repeat_dim)
        .view(
            tuple(
                -1 if (i - dim + t.ndim) % t.ndim == 0 else s
                for i, s in enumerate(t.shape)
            )
        )
    )
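
In forward above, repeat(mask, self.n_labels, 0) tiles the token mask once per label along the batch dimension. A concrete shape sketch:

import torch

from edsnlp.models.stack_crf_ner import repeat

t = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
out = repeat(t, 3, 0)               # shape (6, 2)
# interleaved copies: [1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]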

flatten_dim(t, dim, ndim=1)

Source code in edsnlp/models/stack_crf_ner.py
def flatten_dim(t, dim, ndim=1):
    return t.reshape((*t.shape[:dim], -1, *t.shape[dim + 1 + ndim :]))
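
flatten_dim merges ndim + 1 consecutive dimensions starting at dim into a single one; forward uses it to fold the label dimension into the batch, turning (n_samples, n_labels, n_tokens, num_tags) into (n_samples * n_labels, n_tokens, num_tags). A quick shape sketch:

import torch

from edsnlp.models.stack_crf_ner import flatten_dim

t = torch.zeros(2, 3, 4, 5)
flatten_dim(t, dim=0).shape          # torch.Size([6, 4, 5]): merges dims 0 and 1
flatten_dim(t, dim=1, ndim=2).shape  # torch.Size([2, 60]): merges dims 1, 2 and 3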

create_model(tok2vec, mode, n_labels=None)

Source code in edsnlp/models/stack_crf_ner.py
@registry.layers("eds.stack_crf_ner_model.v1")
def create_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    mode: CRFMode,
    n_labels: Optional[int] = None,
) -> Model[
    Tuple[Iterable[Doc], Optional[Ints2d], Optional[bool]],
    Tuple[Floats1d, Ints2d],
]:
    return wrap_pytorch_model(  # noqa
        encoder=tok2vec,
        pt_model=StackedCRFNERModule(
            input_size=None,  # will be set later during initialization
            n_labels=n_labels,  # will likely be set later during initialization
            mode=mode,
        ),
    )
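
The factory is registered as the thinc layer eds.stack_crf_ner_model.v1, so it is typically referenced from a training config. A direct wiring sketch under stated assumptions: the dummy tok2vec below merely stands in for a real spaCy tok2vec model and emits zero embeddings of an assumed width; only create_model and CRFMode come from this module.

import numpy
from thinc.api import Model

from edsnlp.models.stack_crf_ner import CRFMode, create_model

def dummy_tok2vec_forward(model, docs, is_train):
    # one zero row per token; 768 is an assumed embedding width
    Y = [numpy.zeros((len(doc), 768), dtype="f") for doc in docs]
    return Y, lambda dY: []

tok2vec = Model("dummy-tok2vec", dummy_tok2vec_forward)
model = create_model(tok2vec=tok2vec, mode=CRFMode.joint, n_labels=10)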