Skip to content

edspdf.components.classifiers.deep_classifier

DeepClassifier

Bases: TrainableComponent[PDFDoc, Dict[str, Any], PDFDoc]

Source code in edspdf/components/classifiers/deep_classifier.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
@registry.factory.register("deep-classifier")
class DeepClassifier(TrainableComponent[PDFDoc, Dict[str, Any], PDFDoc]):
    def __init__(
        self,
        embedding: Module,
        labels: Sequence[str] = (),
        activation: ActivationFunction = "gelu",
        dropout_p: float = 0.15,
        scorer: Scorer = classifier_scorer,
    ):
        """
        Runs a deep learning classifier model on the boxes.

        Parameters
        ----------
        embedding: Module
            Embedding module to encode the PDF boxes
        labels: Sequence[str]
            Initial labels of the classifier (will be completed during initialization)
        activation: ActivationFunction
            Name of the activation function
        dropout_p: float
            Dropout probability used on the output of the box and textual encoders
        scorer: Scorer
            Scoring function
        """
        super().__init__()
        # "pollution" is always the first label; dict.fromkeys removes duplicates
        # while keeping the order in which labels were given.
        self.label_vocabulary: Vocabulary = Vocabulary(
            list(dict.fromkeys(["pollution", *labels]))
        )
        self.embedding: Module = embedding

        size = self.embedding.output_size

        # NOTE(review): `linear`, `activation` and `dropout` are created here but
        # are not used by `forward` in this class -- confirm whether that is
        # intended before wiring them in (doing so would change trained models).
        self.linear = torch.nn.Linear(size, size)
        # The output layer is only built in `initialize`, once the number of
        # labels is known.
        self.classifier: torch.nn.Linear = None  # noqa
        self.activation = get_activation_function(activation)
        self.dropout = torch.nn.Dropout(dropout_p)

        # Scoring function
        self.scorer = scorer

    def initialize(self, gold_data: Iterable[PDFDoc]):
        """
        Initialize the embedding, collect the labels found in the gold documents
        and build the output layer accordingly.

        Parameters
        ----------
        gold_data: Iterable[PDFDoc]
            Annotated documents. One-shot iterators (e.g. generators) are
            materialized into a list, because the data is traversed twice.
        """
        from collections.abc import Iterator

        # The documents are iterated once by the embedding and once below: a
        # one-shot iterator would be exhausted after the first pass, silently
        # leaving the label vocabulary with its initial labels only.
        if isinstance(gold_data, Iterator):
            gold_data = list(gold_data)

        self.embedding.initialize(gold_data)

        with self.label_vocabulary.initialization():
            for doc in tqdm(gold_data, desc="Initializing classifier"):
                with self.no_cache():
                    # Called for its side effect of encoding every gold label
                    # while the vocabulary is in its initialization context.
                    self.preprocess(doc, supervision=True)

        self.classifier = torch.nn.Linear(
            in_features=self.embedding.output_size,
            out_features=len(self.label_vocabulary),
        )

    def preprocess(self, doc: PDFDoc, supervision: bool = False) -> Dict[str, Any]:
        """
        Extract the features of a document.

        Parameters
        ----------
        doc: PDFDoc
            Document to preprocess
        supervision: bool
            Whether to also encode the label of each line of the document

        Returns
        -------
        Dict[str, Any]
            The embedding features and the document id, plus a "labels" list
            (one entry per line of `doc.lines`) when `supervision` is True
        """
        result = {
            "embedding": self.embedding.preprocess(doc, supervision=supervision),
            "doc_id": doc.id,
        }
        if supervision:
            text_boxes = doc.lines
            # Lines without a label are encoded as -100 (presumably to match the
            # default `ignore_index` of torch's cross entropy -- TODO confirm).
            result["labels"] = [
                self.label_vocabulary.encode(b.label) if b.label is not None else -100
                for b in text_boxes
            ]
        return result

    def collate(self, batch, device: torch.device) -> Dict:
        """
        Collate preprocessed features into a batch placed on `device`.

        Parameters
        ----------
        batch:
            Mapping holding the "embedding" and "doc_id" entries of the
            preprocessed documents, and optionally their "labels"
        device: torch.device
            Device on which the tensors are created

        Returns
        -------
        Dict
            The collated embedding features, the document ids and, when labels
            are present, a single flat tensor of labels
        """
        collated = {
            "embedding": self.embedding.collate(batch["embedding"], device),
            "doc_id": batch["doc_id"],
        }
        if "labels" in batch:
            collated.update(
                {
                    "labels": torch.as_tensor(flatten(batch["labels"]), device=device),
                }
            )

        return collated

    def forward(self, batch: Dict, supervision=False) -> Dict:
        """
        Compute either the training loss or the predictions for a batch.

        Parameters
        ----------
        batch: Dict
            Collated batch, as returned by `collate`
        supervision: bool
            If True, compute the loss against `batch["labels"]`; otherwise
            return the logits and the predicted label indices

        Returns
        -------
        Dict
            Always contains "loss". With supervision, also "label_loss";
            without, also "logits" and "labels".
        """
        embeds = self.embedding(batch["embedding"])

        output = {"loss": 0}

        # Label prediction / learning
        logits = self.classifier(embeds)
        if supervision:
            targets = batch["labels"]
            # The loss is summed (not averaged) over the batch.
            output["label_loss"] = F.cross_entropy(
                logits,
                targets,
                reduction="sum",
            )
            output["loss"] = output["loss"] + output["label_loss"]
        else:
            output["logits"] = logits
            output["labels"] = logits.argmax(-1)

        return output

    def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]:
        """
        Write the predicted labels back onto the lines of the documents.

        Parameters
        ----------
        docs: Sequence[PDFDoc]
            Documents of the batch, in the order they were collated
        output: Dict
            Output of `forward` without supervision (must contain "labels")

        Returns
        -------
        Sequence[PDFDoc]
            The same documents, with the `label` of each line updated in place
        """
        # Predictions are flat over the whole batch: pair them with the lines of
        # all documents taken in order.
        for b, label in zip(
            (b for doc in docs for b in doc.lines),
            output["labels"].cpu().tolist(),
        ):
            if b.text == "":
                # Empty lines get no label, whatever the model predicted.
                b.label = None
            else:
                b.label = self.label_vocabulary.decode(label)
        return docs

__init__(embedding, labels=(), activation='gelu', dropout_p=0.15, scorer=classifier_scorer)

Runs a deep learning classifier model on the boxes.

PARAMETER DESCRIPTION
labels

Initial labels of the classifier (will be completed during initialization)

TYPE: Sequence[str] DEFAULT: ()

embedding

Embedding module to encode the PDF boxes

TYPE: Module

activation

Name of the activation function

TYPE: ActivationFunction DEFAULT: 'gelu'

dropout_p

Dropout probability used on the output of the box and textual encoders

TYPE: float DEFAULT: 0.15

scorer

Scoring function

TYPE: Scorer DEFAULT: classifier_scorer

Source code in edspdf/components/classifiers/deep_classifier.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def __init__(
    self,
    embedding: Module,
    labels: Sequence[str] = (),
    activation: ActivationFunction = "gelu",
    dropout_p: float = 0.15,
    scorer: Scorer = classifier_scorer,
):
    """
    Set up a deep learning classifier that is run on the PDF boxes.

    Parameters
    ----------
    embedding: Module
        Module used to encode the PDF boxes
    labels: Sequence[str]
        Labels known from the start; more are added during initialization
    activation: ActivationFunction
        Name of the activation function to use
    dropout_p: float
        Dropout probability used on the output of the box and textual encoders
    scorer: Scorer
        Function used to score the component
    """
    super().__init__()

    # "pollution" always comes first; dict.fromkeys drops duplicated labels
    # while preserving the order in which they were given.
    seed_labels = dict.fromkeys(["pollution", *labels])
    self.label_vocabulary: Vocabulary = Vocabulary(list(seed_labels))
    self.embedding: Module = embedding

    hidden_size = self.embedding.output_size
    self.linear = torch.nn.Linear(
        in_features=hidden_size,
        out_features=hidden_size,
    )
    # No output layer is created here; the attribute starts out empty.
    self.classifier: torch.nn.Linear = None  # noqa
    self.activation = get_activation_function(activation)
    self.dropout = torch.nn.Dropout(p=dropout_p)

    # Scoring function
    self.scorer = scorer