Skip to content

edspdf.layers.box_embedding

BoxEmbedding

Bases: Module

Source code in edspdf/layers/box_embedding.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
@registry.factory.register("box-embedding")
class BoxEmbedding(Module):
    """
    Embeds document boxes by summing the (dropout-regularized) outputs of an
    optional layout encoder and an optional text encoder, then optionally
    passing the summed embeddings through a contextualizer module.
    """

    def __init__(
        self,
        size: int,
        dropout_p: float = 0.2,
        layout_encoder: Optional[Module] = None,
        text_encoder: Optional[Module] = None,
        contextualizer: Optional[Module] = None,
    ):
        """
        Encodes boxes using a combination of layout and text features.

        Parameters
        ----------
        size: int
            Size of the output box embedding; must be divisible by 6
        dropout_p: float
            Dropout probability used on the output of the box and textual encoders
        layout_encoder: Optional[Module]
            Encoder for the layout (geometric) features of the boxes
        text_encoder: Optional[Module]
            Encoder for the textual content of the boxes
        contextualizer: Optional[Module]
            Optional module applied to the summed embeddings, receiving both
            the embeddings and the collated box features

        Raises
        ------
        ValueError
            If ``size`` is not divisible by 6.
        """
        super().__init__()

        # NOTE(review): was an `assert`, which is silently stripped under
        # `python -O`; raise so the size contract is always enforced.
        if size % 6 != 0:
            raise ValueError("Embedding dimension must be divisible by 6")

        self.size = size

        # NOTE(review): `layout_encoder` previously defaulted to `{}` — a
        # shared mutable default whose type contradicts Optional[Module]; a
        # bare dict would pass the `is not None` guards below and crash in
        # preprocess/collate/forward. If the config system relied on `{}` to
        # instantiate a default encoder, restore that explicitly — TODO confirm.
        self.layout_encoder = layout_encoder
        self.text_encoder = text_encoder
        self.contextualizer = contextualizer
        self.dropout = torch.nn.Dropout(dropout_p)

    @property
    def output_size(self) -> int:
        """Dimension of the embeddings produced by `forward`."""
        return self.size

    def initialize(self, gold_data, **kwargs):
        """Initialize this module and forward `size` to each sub-module."""
        super().initialize(gold_data, **kwargs)
        if self.text_encoder is not None:
            self.text_encoder.initialize(gold_data, size=self.size)
        if self.layout_encoder is not None:
            self.layout_encoder.initialize(gold_data, size=self.size)
        if self.contextualizer is not None:
            self.contextualizer.initialize(gold_data, input_size=self.size)

    def preprocess(self, doc, supervision: bool = False):
        """
        Extract per-document features for each configured sub-encoder.

        Returns a dict with "boxes" (layout features) and "texts" (textual
        features); an entry is None when the corresponding encoder is absent.
        """
        return {
            "boxes": self.layout_encoder.preprocess(doc, supervision=supervision)
            if self.layout_encoder is not None
            else None,
            "texts": self.text_encoder.preprocess(doc, supervision=supervision)
            if self.text_encoder is not None
            else None,
        }

    def collate(self, batch, device: torch.device):
        """
        Collate preprocessed features into tensors on `device`, delegating to
        each configured sub-encoder; absent encoders yield None entries.
        """
        return {
            "texts": self.text_encoder.collate(batch["texts"], device)
            if self.text_encoder is not None
            else None,
            "boxes": self.layout_encoder.collate(batch["boxes"], device)
            if self.layout_encoder is not None
            else None,
        }

    def forward(self, batch, supervision=False):
        """
        Sum the dropout-regularized outputs of the configured encoders, then
        contextualize the result if a contextualizer is set.
        """
        # NOTE(review): with neither encoder configured, sum([]) is the int 0,
        # which is then fed to the contextualizer — presumably at least one
        # encoder is always configured; verify against callers.
        embeds = sum(
            [
                self.dropout(encoder.module_forward(batch[name]))
                for name, encoder in (
                    ("boxes", self.layout_encoder),
                    ("texts", self.text_encoder),
                )
                if encoder is not None
            ]
        )

        if self.contextualizer is not None:
            embeds = self.contextualizer(
                embeds=embeds,
                boxes=batch["boxes"],
            )

        return embeds

__init__(size, dropout_p=0.2, layout_encoder={}, text_encoder=None, contextualizer=None)

Encodes boxes using a combination of layout and text features.

PARAMETER DESCRIPTION
size

Size of the output box embedding

TYPE: int

dropout_p

Dropout probability used on the output of the box and textual encoders

TYPE: float DEFAULT: 0.2

text_encoder

Config for the text encoder

TYPE: Optional[Module] DEFAULT: None

layout_encoder

Config for the layout encoder

TYPE: Optional[Module] DEFAULT: {}

Source code in edspdf/layers/box_embedding.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def __init__(
    self,
    size: int,
    dropout_p: float = 0.2,
    layout_encoder: Optional[Module] = None,
    text_encoder: Optional[Module] = None,
    contextualizer: Optional[Module] = None,
):
    """
    Encodes boxes using a combination of layout and text features.

    Parameters
    ----------
    size: int
        Size of the output box embedding; must be divisible by 6
    dropout_p: float
        Dropout probability used on the output of the box and textual encoders
    layout_encoder: Optional[Module]
        Encoder for the layout (geometric) features of the boxes
    text_encoder: Optional[Module]
        Encoder for the textual content of the boxes
    contextualizer: Optional[Module]
        Optional module applied to the summed embeddings

    Raises
    ------
    ValueError
        If ``size`` is not divisible by 6.
    """
    super().__init__()

    # NOTE(review): was an `assert`, which is silently stripped under
    # `python -O`; raise so the size contract is always enforced.
    if size % 6 != 0:
        raise ValueError("Embedding dimension must be divisible by 6")

    self.size = size

    # NOTE(review): `layout_encoder` previously defaulted to `{}` — a shared
    # mutable default whose type contradicts Optional[Module]; the rest of the
    # class only guards on `is not None` before calling Module methods, so a
    # bare dict would crash downstream. If the config system relied on `{}` to
    # instantiate a default encoder, restore that explicitly — TODO confirm.
    self.layout_encoder = layout_encoder
    self.text_encoder = text_encoder
    self.contextualizer = contextualizer
    self.dropout = torch.nn.Dropout(dropout_p)