BoxTransformerLayer

BoxTransformerLayer combines a self-attention layer with a linear->activation->linear feed-forward transformation. It is used as a building block of the BoxTransformerModule module.

Parameters

input_size

Input embedding size

TYPE: int

num_heads

Number of attention heads in the attention layer

TYPE: int DEFAULT: 2

dropout_p

Dropout probability applied both in the attention layer and in the embedding projections

TYPE: float DEFAULT: 0.0

head_size

Size of each head in the attention layer

TYPE: Optional[int] DEFAULT: None

activation

Activation function used in the linear->activation->linear transformation

TYPE: ActivationFunction DEFAULT: 'gelu'

init_resweight

Initial weight of the residual gates. At 0, the layer initially acts as an identity function; at 1, as a standard Transformer layer. Initializing with a value close to 0 can help training converge.

TYPE: float DEFAULT: 0.0

attention_mode

Mode of the relative-position-infused attention layer. See the relative attention documentation for more information.

TYPE: Sequence[Literal['c2c', 'c2p', 'p2c']] DEFAULT: ('c2c', 'c2p', 'p2c')

position_embedding

Position embedding to use as key/query position embedding in the attention computation.

TYPE: Optional[Union[FloatTensor, Parameter]] DEFAULT: None
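For reference, the construction below is a minimal sketch, not part of the official documentation: the import path, the position embedding shape, and the hyperparameter values are assumptions to adjust to your installation; only the parameter names and defaults come from the descriptions above.

```python
import torch

# Assumed import path; adjust to wherever BoxTransformerLayer lives in your install.
from edspdf.layers.box_transformer import BoxTransformerLayer

# Hypothetical relative position embedding table: 64 position buckets of size 48.
# Its exact expected shape is an assumption; check the relative attention documentation.
position_embedding = torch.nn.Parameter(torch.randn(64, 48))

layer = BoxTransformerLayer(
    input_size=48,
    num_heads=2,
    dropout_p=0.1,
    activation="gelu",
    init_resweight=0.0,                    # start (almost) as an identity function
    attention_mode=("c2c", "c2p", "p2c"),
    position_embedding=position_embedding,
)
```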

forward

Forward pass of the BoxTransformerLayer

Parameters
embeds

Embeddings to contextualize. Shape: n_samples * n_keys * input_size

TYPE: FloatTensor

mask

Mask of the embeddings. 0 indicates a padding element. Shape: n_samples * n_keys

TYPE: BoolTensor

relative_positions

Positions of the keys relative to the query elements. Shape: n_samples * n_queries * n_keys * n_coordinates (2 for x/y)

TYPE: LongTensor

no_position_mask

Key / query pairs for which the position attention terms should be disabled. Shape: n_samples * n_queries * n_keys

TYPE: Optional[BoolTensor] DEFAULT: None

Returns
Tuple[FloatTensor, FloatTensor]
  • Contextualized embeddings. Shape: n_samples * n_keys * input_size
  • Attention logits. Shape: n_samples * n_queries * n_keys * n_heads
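
The call below is a usage sketch of the forward pass, built on the layer constructed in the example above. The argument names and tensor shapes follow this page; treating relative_positions as bucketed indices into the (assumed) position embedding table is itself an assumption to verify against the relative attention documentation.

```python
n_samples, n_keys, input_size = 4, 50, 48

embeds = torch.randn(n_samples, n_keys, input_size)
mask = torch.ones(n_samples, n_keys, dtype=torch.bool)  # all True: no padding

# Relative x/y position of every key with respect to every query.
# Zeros place every pair in bucket 0 of the assumed position embedding table.
relative_positions = torch.zeros(n_samples, n_keys, n_keys, 2, dtype=torch.long)

embeds_out, attn_logits = layer(
    embeds=embeds,
    mask=mask,
    relative_positions=relative_positions,
    no_position_mask=None,
)

print(embeds_out.shape)   # expected: (n_samples, n_keys, input_size)
print(attn_logits.shape)  # expected: (n_samples, n_keys, n_keys, num_heads)
```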