Skip to content

edspdf.layers.box_layout_embedding

BoxLayoutEmbedding

Bases: Module

Encodes a box using its geometrical features, as extracted by the BoxLayoutPreprocessor module.

Source code in edspdf/layers/box_layout_embedding.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@registry.factory.register("box-layout-embedding")
class BoxLayoutEmbedding(Module):
    """
    Encodes a box using its geometrical features, as extracted by the
    BoxLayoutPreprocessor module.

    Six features (xmin, ymin, xmax, ymax, width, height) are each turned into
    a discrete position index, embedded into ``size // 6`` dimensions, and
    concatenated. Learned "first page" / "last page" vectors are then added,
    weighted by the corresponding flags of the batch.
    """

    # Number of geometric features concatenated in `forward`. The output size
    # must be a multiple of it so that the concatenation has the same width as
    # the first/last page embeddings it is added to.
    _N_FEATURES = 6

    def __init__(
        self,
        n_positions: int,
        size: Optional[int] = None,
        x_mode: PositionEmbeddingMode = "sin",
        y_mode: PositionEmbeddingMode = "sin",
        w_mode: PositionEmbeddingMode = "sin",
        h_mode: PositionEmbeddingMode = "sin",
    ):
        """

        Parameters
        ----------
        n_positions: int
            Number of discrete positions (buckets) in each coordinate
            embedding table
        size: Optional[int]
            Size of the output box embedding. Must be a multiple of 6 once the
            module is initialized.
        x_mode: PositionEmbeddingMode
            Position embedding mode of the x coordinates ("sin" for a
            sinusoidal embedding, anything else for a learned embedding)
        y_mode: PositionEmbeddingMode
            Position embedding mode of the y coordinates
        w_mode: PositionEmbeddingMode
            Position embedding mode of the width features
        h_mode: PositionEmbeddingMode
            Position embedding mode of the height features
        """

        super().__init__()

        self.n_positions = n_positions
        self.size = size
        self.x_mode = x_mode
        self.y_mode = y_mode
        self.w_mode = w_mode
        self.h_mode = h_mode
        # Sub-modules/parameters are created lazily in `initialize`, once the
        # final `size` is known.
        self.x_embedding = None
        self.y_embedding = None
        self.w_embedding = None
        self.h_embedding = None
        self.first_page_embedding = None
        self.last_page_embedding = None

        self.box_preprocessor = BoxLayoutPreprocessor()
        self.preprocess = self.box_preprocessor.preprocess
        self.collate = self.box_preprocessor.collate

    def initialize(self, gold_data: Iterable, size: int = None, **kwargs):
        """
        Build the embedding tables and page embeddings.

        Parameters
        ----------
        gold_data: Iterable
            Training data, forwarded to the parent ``initialize``
        size: int
            Optional output size, forwarded to the parent ``initialize``

        Raises
        ------
        ValueError
            If ``self.size`` is unset or not a multiple of 6 after the parent
            initialization.
        """
        super().initialize(gold_data, size=size, **kwargs)
        n_pos, size = self.n_positions, self.size

        if size is None:
            raise ValueError(
                "BoxLayoutEmbedding requires `size` to be set before "
                "initialization."
            )
        if size % self._N_FEATURES != 0:
            # Otherwise the concatenated features would be narrower than the
            # page embeddings and `forward` would fail on the addition.
            raise ValueError(
                f"BoxLayoutEmbedding `size` must be a multiple of "
                f"{self._N_FEATURES}, got {size}."
            )
        feature_size = size // self._N_FEATURES

        self.x_embedding = self._make_embed(n_pos, feature_size, self.x_mode)
        self.y_embedding = self._make_embed(n_pos, feature_size, self.y_mode)
        self.w_embedding = self._make_embed(n_pos, feature_size, self.w_mode)
        self.h_embedding = self._make_embed(n_pos, feature_size, self.h_mode)
        self.first_page_embedding = torch.nn.Parameter(torch.randn(self.size))
        self.last_page_embedding = torch.nn.Parameter(torch.randn(self.size))

    @classmethod
    def _make_embed(cls, n_positions, size, mode):
        """Return a sinusoidal embedding for mode "sin", else a learned one."""
        if mode == "sin":
            return SinusoidalEmbedding(n_positions, size)
        else:
            return torch.nn.Embedding(n_positions, size)

    def _embed_position(self, embedding, values, scale=1):
        """
        Bucket continuous ``values`` into ``n_positions`` indices and embed them.

        Indices are clamped to ``[0, n_positions - 1]`` so that out-of-range
        coordinates (e.g. boxes slightly outside the page) still map to a valid
        table row instead of raising.
        """
        # Multiply left-to-right (values * scale, then * n_positions) to keep
        # exactly the same floating-point rounding as before.
        indices = (
            (values * scale * self.n_positions)
            .clamp(min=0, max=self.n_positions - 1)
            .long()
        )
        return embedding(indices)

    def forward(self, batch):
        """
        Embed a collated batch of box features.

        ``batch`` is expected to contain the keys "xmin", "ymin", "xmax",
        "ymax", "width", "height", "first_page" and "last_page" (presumably as
        produced by ``BoxLayoutPreprocessor.collate`` — TODO confirm). Returns
        a tensor whose last dimension is ``size``.
        """
        embed = self._embed_position
        layout = torch.cat(
            [
                embed(self.x_embedding, batch["xmin"]),
                embed(self.y_embedding, batch["ymin"]),
                # xmax / ymax intentionally share the x / y embedding tables.
                embed(self.x_embedding, batch["xmax"]),
                embed(self.y_embedding, batch["ymax"]),
                embed(self.w_embedding, batch["width"]),
                # Heights are scaled by 5 before bucketing; presumably because
                # box heights span a small fraction of the page — confirm.
                embed(self.h_embedding, batch["height"], scale=5),
            ],
            dim=-1,
        )
        return (
            layout
            + self.first_page_embedding * batch["first_page"][..., None]
            + self.last_page_embedding * batch["last_page"][..., None]
        )

__init__(n_positions, size=None, x_mode='sin', y_mode='sin', w_mode='sin', h_mode='sin')

PARAMETER DESCRIPTION
size

Size of the output box embedding. Each of the six geometric features uses `size // 6` dimensions, so it should be a multiple of 6.

TYPE: Optional[int] DEFAULT: None

n_positions

Number of discrete positions (buckets) in each coordinate embedding table

TYPE: int

x_mode

Position embedding mode of the x coordinates

TYPE: PositionEmbeddingMode DEFAULT: 'sin'

y_mode

Position embedding mode of the y coordinates

TYPE: PositionEmbeddingMode DEFAULT: 'sin'

w_mode

Position embedding mode of the width features

TYPE: PositionEmbeddingMode DEFAULT: 'sin'

h_mode

Position embedding mode of the height features

TYPE: PositionEmbeddingMode DEFAULT: 'sin'

Source code in edspdf/layers/box_layout_embedding.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    n_positions: int,
    size: Optional[int] = None,
    x_mode: PositionEmbeddingMode = "sin",
    y_mode: PositionEmbeddingMode = "sin",
    w_mode: PositionEmbeddingMode = "sin",
    h_mode: PositionEmbeddingMode = "sin",
):
    """

    Parameters
    ----------
    n_positions: int
        Number of discrete positions (buckets) in each coordinate
        embedding table
    size: Optional[int]
        Size of the output box embedding
    x_mode: PositionEmbeddingMode
        Position embedding mode of the x coordinates
    y_mode: PositionEmbeddingMode
        Position embedding mode of the y coordinates
    w_mode: PositionEmbeddingMode
        Position embedding mode of the width features
    h_mode: PositionEmbeddingMode
        Position embedding mode of the height features
    """

    super().__init__()

    self.n_positions = n_positions
    self.size = size
    self.x_mode = x_mode
    self.y_mode = y_mode
    self.w_mode = w_mode
    self.h_mode = h_mode
    # Embedding modules/parameters are left unset here and created later,
    # once the final size is known.
    self.x_embedding = None
    self.y_embedding = None
    self.w_embedding = None
    self.h_embedding = None
    self.first_page_embedding = None
    self.last_page_embedding = None

    self.box_preprocessor = BoxLayoutPreprocessor()
    self.preprocess = self.box_preprocessor.preprocess
    self.collate = self.box_preprocessor.collate