Skip to content

edspdf.utils.torch

compute_pdf_relative_positions(x0, y0, x1, y1, width, height, n_relative_positions)

Compute relative positions between boxes. Input boxes must be split between pages, with shape n_pages * n_boxes.

PARAMETER DESCRIPTION
x0

y0

x1

y1

width

height

n_relative_positions

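Maximum range of embeddable relative positions between boxes (further distances are meant to be capped to ±n_relative_positions // 2; note that this function itself does not clamp its output). Also used as the scale factor applied to horizontal offsets.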

RETURNS DESCRIPTION
torch.LongTensor

Shape: n_pages * n_boxes * n_boxes * 2

Source code in edspdf/utils/torch.py (lines 15–63)
def compute_pdf_relative_positions(x0, y0, x1, y1, width, height, n_relative_positions):
    """
    Compute relative positions between boxes.
    Input boxes must be split between pages with the shape n_pages * n_boxes

    For each page and each ordered pair of boxes (A, B), where A is indexed
    along axis 1 of the output and B along axis 2, two integer offsets are
    produced:

    - horizontal: ``(B.x0 - A.x0) * n_relative_positions``, truncated to an
      integer;
    - vertical: with ``g_end = B.y1 - A.y0`` and ``g_begin = B.y0 - A.y1``,
      it is ``ceil(min(g_end, g_begin) / ref + 0.5)`` when both gaps exceed
      ``-0.5``, else ``floor(max(g_end, g_begin) / ref - 0.5)`` when both
      gaps are below ``0.5``, else 0. ``ref`` is A's height when
      ``B.y0 >= A.y0`` and B's height otherwise.

    NOTE(review): the 0.5 margin in the branch conditions is applied in raw
    coordinate units (before dividing by ``ref``). If coordinates are
    page-normalized, this margin spans half a page. Confirm this is intended.

    Parameters
    ----------
    x0: torch.FloatTensor
        Begin x coordinate of each box, shape n_pages * n_boxes
    y0: torch.FloatTensor
        Begin y coordinate of each box
    x1: torch.FloatTensor
        End x coordinate of each box (not used by this computation)
    y1: torch.FloatTensor
        End y coordinate of each box
    width: torch.FloatTensor
        Per-box width (not used by this computation)
    height: torch.FloatTensor
        Per-box height, used as the vertical unit ``ref`` described above
    n_relative_positions: int
        Maximum range of embeddable relative positions between boxes. It is
        also the scale factor applied to horizontal offsets. NOTE: this
        function does not clamp to ±n_relative_positions // 2; presumably
        the caller does (TODO confirm).

    Returns
    -------
    torch.LongTensor
        Shape: n_pages * n_boxes * n_boxes * 2. ``[..., 0]`` holds the
        horizontal offset and ``[..., 1]`` the vertical offset.
    """
    # Pairwise views: axis 1 indexes box A (query), axis 2 indexes box B (key).
    a_x0, b_x0 = x0[:, :, None], x0[:, None, :]
    a_y0, b_y0 = y0[:, :, None], y0[:, None, :]
    a_y1, b_y1 = y1[:, :, None], y1[:, None, :]

    # Horizontal: A begin -> B begin, scaled, then truncated by the long cast.
    horizontal = ((b_x0 - a_x0) * n_relative_positions).long()

    # Reference height: A's height when B begins at or below A's begin,
    # otherwise B's height.
    begin_shift = b_y0 - a_y0
    use_a = (begin_shift >= 0).float()
    use_b = (begin_shift < 0).float()
    ref_height = use_a * height.float()[:, :, None] + use_b * height[:, None, :]

    gap_end = b_y1 - a_y0  # A begin -> B end
    gap_begin = b_y0 - a_y1  # A end -> B begin
    margin = 0.5

    # Intended as "A fully above B" / "A fully below B"; the margin is in raw
    # coordinate units (see NOTE in the docstring).
    mostly_above = ((gap_end + margin).sign() > 0) & (
        (gap_begin + margin).sign() > 0
    )
    mostly_below = ((gap_end - margin).sign() < 0) & (
        (gap_begin - margin).sign() < 0
    )

    above_steps = (torch.minimum(gap_end, gap_begin) / ref_height + margin).ceil()
    below_steps = (torch.maximum(gap_end, gap_begin) / ref_height - margin).floor()

    vertical = torch.where(
        mostly_above,
        above_steps,
        torch.where(mostly_below, below_steps, 0),
    )
    # Round magnitudes up while keeping the sign, then cast to integers.
    vertical = (vertical.abs().ceil() * vertical.sign()).long()

    return torch.stack([horizontal, vertical], dim=-1)

log_einsum_exp(formula, *ops)

Numerically stable logarithm of the einsum of the exponentiated operands.

Source code in edspdf/utils/torch.py (lines 76–84)
def log_einsum_exp(formula, *ops):
    """
    Numerically stable log of einsum of exponents of operands

    Computes ``log(torch.einsum(formula, *[op.exp() for op in ops]))``
    without overflowing. Each operand is shifted by its own global maximum
    before exponentiation, and the shifts are added back in log space
    afterwards. Because einsum is linear in each operand, scaling operand k
    by ``exp(-m_k)`` scales the result by ``exp(-sum_k m_k)``.

    Parameters
    ----------
    formula: str
        Einsum equation, passed unchanged to ``torch.einsum``
    ops: torch.Tensor
        Operands, given in log space

    Returns
    -------
    torch.Tensor
        Log of the einsum of the exponentiated operands
    """
    maxes = []
    for op in ops:
        op_max = op.max()
        # A non-finite maximum (e.g. a fully masked operand filled with -inf)
        # would make ``op - op_max`` NaN (-inf - -inf). Shift by 0 instead so
        # such operands propagate as exp(-inf) = 0 (or inf) rather than NaN.
        op_max = torch.where(
            torch.isfinite(op_max), op_max, torch.zeros_like(op_max)
        )
        maxes.append(op_max)
    shifted = [op - op_max for op, op_max in zip(ops, maxes)]
    res = torch.einsum(formula, *(op.exp() for op in shifted)).log()
    return res + sum(maxes)