25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137 | @registry.factory.register("deep-classifier")
class DeepClassifier(TrainableComponent[PDFDoc, Dict[str, Any], PDFDoc]):
    """Trainable component that predicts a label for every line of a document."""

    def __init__(
        self,
        embedding: Module,
        labels: Sequence[str] = (),
        activation: ActivationFunction = "gelu",
        dropout_p: float = 0.15,
        scorer: Scorer = classifier_scorer,
    ):
        """
        Build a deep learning classifier that labels the boxes of a document.

        Parameters
        ----------
        embedding: Module
            Module used to encode the PDF boxes; its ``output_size`` fixes the
            width of the layers created here
        labels: Sequence[str]
            Initial labels of the classifier (more labels can be collected
            while `initialize` runs)
        activation: ActivationFunction
            Name of the activation function, resolved through
            `get_activation_function`
        dropout_p: float
            Probability given to the `torch.nn.Dropout` layer created here
        scorer: Scorer
            Scoring function
        """
        super().__init__()

        # "pollution" always comes first; user labels follow in their given
        # order with duplicates dropped (first occurrence wins).
        initial_labels = list(dict.fromkeys(["pollution", *labels]))
        self.label_vocabulary: Vocabulary = Vocabulary(initial_labels)

        self.embedding: Module = embedding
        hidden_size = self.embedding.output_size

        # NOTE(review): `linear`, `activation` and `dropout` are created here
        # but `forward` in this class never calls them — confirm whether they
        # are kept on purpose (e.g. for compatibility with saved weights).
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

        # The output layer can only be sized once the label set is known, so
        # it stays unset until `initialize` is called.
        self.classifier: torch.nn.Linear = None  # noqa

        self.activation = get_activation_function(activation)
        self.dropout = torch.nn.Dropout(dropout_p)

        # Scoring function
        self.scorer = scorer

    def initialize(self, gold_data: Iterable[PDFDoc]):
        """
        Initialize the embedding, collect the labels and create the output layer.

        Parameters
        ----------
        gold_data: Iterable[PDFDoc]
            Annotated documents whose line labels are fed to the vocabulary
        """
        # NOTE(review): `gold_data` is handed to the embedding and then
        # iterated again below; a one-shot iterator would presumably be
        # exhausted before the second pass — confirm callers pass something
        # that can be iterated more than once.
        self.embedding.initialize(gold_data)

        # Running the supervised preprocessing on each document encodes its
        # labels while the vocabulary is in its initialization context.
        with self.label_vocabulary.initialization():
            for gold_doc in tqdm(gold_data, desc="Initializing classifier"):
                with self.no_cache():
                    self.preprocess(gold_doc, supervision=True)

        # One output unit per label currently in the vocabulary.
        self.classifier = torch.nn.Linear(
            in_features=self.embedding.output_size,
            out_features=len(self.label_vocabulary),
        )

    def preprocess(self, doc: PDFDoc, supervision: bool = False) -> Dict[str, Any]:
        """
        Turn a document into the features consumed by `collate`.

        Parameters
        ----------
        doc: PDFDoc
            Document to preprocess
        supervision: bool
            Whether to also encode the labels of the document lines

        Returns
        -------
        Dict[str, Any]
            The embedding features and the document id, plus one encoded label
            per line under "labels" when `supervision` is true
        """
        features = {
            "embedding": self.embedding.preprocess(doc, supervision=supervision),
            "doc_id": doc.id,
        }
        if not supervision:
            return features

        # Unlabelled lines get -100, the default `ignore_index` of
        # `F.cross_entropy`, so they do not contribute to the loss.
        features["labels"] = [
            -100 if line.label is None else self.label_vocabulary.encode(line.label)
            for line in doc.lines
        ]
        return features

    def collate(self, batch, device: torch.device) -> Dict:
        """
        Assemble preprocessed features into a batch located on `device`.

        Parameters
        ----------
        batch
            Mapping holding the "embedding" and "doc_id" entries and,
            optionally, the "labels" entry produced by `preprocess`
        device: torch.device
            Device on which the tensors are created

        Returns
        -------
        Dict
            The collated embedding input, the document ids and, when labels
            are present, a single tensor of the flattened labels
        """
        result = {
            "embedding": self.embedding.collate(batch["embedding"], device),
            "doc_id": batch["doc_id"],
        }
        if "labels" in batch:
            # The per-document label lists are flattened into one tensor.
            flat_labels = flatten(batch["labels"])
            result["labels"] = torch.as_tensor(flat_labels, device=device)
        return result

    def forward(self, batch: Dict, supervision=False) -> Dict:
        """
        Score the labels of a collated batch.

        Parameters
        ----------
        batch: Dict
            Batch produced by `collate`
        supervision: bool
            Whether to compute the loss instead of the predictions

        Returns
        -------
        Dict
            With supervision: "loss" and "label_loss" (summed cross-entropy).
            Without supervision: "loss" (left at 0), "logits" and "labels"
            (index of the highest logit).
        """
        encoded = self.embedding(batch["embedding"])
        result = {"loss": 0}

        # Label prediction / learning
        scores = self.classifier(encoded)

        if not supervision:
            result["logits"] = scores
            result["labels"] = scores.argmax(-1)
            return result

        gold = batch["labels"]
        label_loss = F.cross_entropy(scores, gold, reduction="sum")
        result["label_loss"] = label_loss
        result["loss"] = result["loss"] + label_loss
        return result

    def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]:
        """
        Write the predicted labels back onto the lines of the documents.

        Parameters
        ----------
        docs: Sequence[PDFDoc]
            Documents of the batch, in the order used to build it
        output: Dict
            Output of `forward`, holding the predicted indices under "labels"

        Returns
        -------
        Sequence[PDFDoc]
            The same documents, annotated in place
        """
        # Lines of all documents, walked in the same order as the predictions.
        all_lines = (line for doc in docs for line in doc.lines)
        predicted = output["labels"].cpu().tolist()

        for line, label_index in zip(all_lines, predicted):
            # Lines with empty text are left unlabelled.
            line.label = (
                None if line.text == "" else self.label_vocabulary.decode(label_index)
            )
        return docs
|