edspdf.pipeline

Pipeline

The Pipeline is the core object of EDS-PDF. It is responsible for orchestrating the components and processing PDF documents end-to-end.

A pipeline is usually created empty and then populated with components via the add_pipe method. Here is an example:

pipeline = Pipeline()
pipeline.add_pipe("pdfminer-extractor")
pipeline.add_pipe("mask-classifier")
pipeline.add_pipe("simple-aggregator")
Source code in edspdf/pipeline.py
@validate_arguments
class Pipeline:
    """
    The Pipeline is the core object of EDS-PDF. It is responsible for
    orchestrating the components and processing PDF documents end-to-end.

    A pipeline is usually created empty and then populated with components via the
    `add_pipe` method. Here is an example:

    ```python
    pipeline = Pipeline()
    pipeline.add_pipe("pdfminer-extractor")
    pipeline.add_pipe("mask-classifier")
    pipeline.add_pipe("simple-aggregator")
    ```
    """

    def __init__(
        self,
        components: Optional[List[str]] = None,
        components_config: Optional[Dict[str, Component]] = None,
        batch_size: int = 4,
    ):
        """
        Initializes the pipeline. The pipeline is empty by default and can be
        populated with components via the `add_pipe` method.

        Parameters
        ----------
        components: Optional[List[str]]
            List of component names
        components_config: Optional[Dict[str, Component]]
            Dictionary of component configurations. The keys of the dictionary must
            match the component names. The values are the component configurations,
            which can contain unresolved configuration for nested components or
            instances of components.
        batch_size: int
            The default number of documents to process in parallel when running the
            pipeline
        """
        super().__init__()
        if components is None:
            components = ComponentsMap({})
        if components_config is None:
            components_config = {}
        self.components = ComponentsMap({})
        for name in components:
            component = components_config[name]
            component.name = name
            self.components[name] = component

        self.batch_size = batch_size
        self.meta = {}

    def add_pipe(self, factory_name: Union[str, Callable], name=None, config=None):
        """
        Adds a component to the pipeline. The component can be either a factory name or
        an instantiated component. If a factory name is provided, the component will be
        instantiated with the class from the registry matching the factory name and
        using the provided config as arguments.

        Parameters
        ----------
        factory_name: Union[str, Callable]
            Either a factory name or an instantiated component
        name: str
            Name of the component
        config: Dict
            Configuration of the component. The configuration can contain unresolved
            configuration for nested components such as
            `{"@factory_name": "my-sub-component", ...}`

        Returns
        -------
        Component
            The added component
        """
        if isinstance(factory_name, str):
            if config is None:
                config = {}
            # config = Config(config).resolve()
            cls = registry.factory.get(factory_name)
            component = cls(**config)
        elif hasattr(factory_name, "__call__") or hasattr(factory_name, "process"):
            if config is not None:
                raise TypeError(
                    "Cannot provide both an instantiated component and its config"
                )
            component = factory_name
            factory_name = next(
                k
                for k, v in registry.factory.get_all().items()
                if component.__class__ == v
            )
        else:
            raise TypeError(
                "`add_pipe` first argument must either be a factory name "
                f"or an instantiated component. You passed a {type(factory_name)}"
            )
        if name is None:
            if factory_name is None:
                raise TypeError(
                    "Could not automatically assign a name for component {}: either"
                    "provide a name explicitly to the add_pipe function, or define a "
                    "factory_name field on this component.".format(component)
                )
            name = factory_name

        self.components[name] = component
        component.name = name

        if not (hasattr(component, "process") or hasattr(component, "__call__")):
            raise TypeError("Component must have a process method or be callable")

        return component

    def reset_cache(self, cache: Optional[CacheEnum] = None):
        """
        Reset the caches of the components in this pipeline

        Parameters
        ----------
        cache: Optional[CacheEnum]
            The cache to reset (either `preprocess`, `collate` or `forward`).
            If None, all caches are reset.
        """
        for component in self.components.values():
            try:
                component.reset_cache(cache)
            except AttributeError:
                pass

    def __call__(self, doc: InputT) -> OutputT:
        """
        Applies the pipeline to a sample

        Parameters
        ----------
        doc: InputT
            The document to process

        Returns
        -------
        OutputT
        """
        self.reset_cache()
        for name, component in self.components.items():
            doc = component(doc)
        return doc

    def pipe(self, docs: Iterable[InputT]) -> Iterable:
        """
        Applies the pipeline to a collection of documents

        Parameters
        ----------
        docs: Iterable[InputT]
            The documents to process

        Returns
        -------
        Iterable
            An iterable collection of processed documents
        """
        for batch in batchify(docs, batch_size=self.batch_size):
            self.reset_cache()
            for component in self.components.values():
                batch = component.batch_process(batch)
            yield from batch

    def initialize(self, data: Iterable[InputT]):
        """
        Initialize the components of the pipeline.
        Each component must be initialized before the components that follow it
        are run. Since a component might need the full training data to be
        initialized, all of the data may be fed to it, which makes batch caching
        impossible.

        Therefore, caching is disabled during the entire operation, so heavy
        computations (such as embeddings) that are usually shared will be
        repeated for each initialized component.

        Parameters
        ----------
        data: Iterable[InputT]
            The documents to initialize the components with

        """
        # Component initialization
        print("Initializing components")
        data = multi_tee(data)

        with self.no_cache():
            for name, component in self.components.items():
                if not component.initialized:
                    component.initialize(data)
                    print(f"Component {repr(name)} initialized")

    def score(self, docs: Sequence[InputT]):
        """
        Scores the pipeline against a sequence of annotated documents

        Parameters
        ----------
        docs: Sequence[InputT]
            The documents to score

        Returns
        -------
        Dict[str, Any]
            A dictionary containing the metrics of the pipeline, as well as its
            speed. Each component that has a scorer is also scored, and its
            metrics are included in the returned dictionary under a key named
            after the component.
        """
        self.train(False)
        inputs: Sequence[InputT] = copy.deepcopy(docs)
        golds: Iterable[Dict[str, InputT]] = docs

        scored_components = {}

        # Predicting intermediate steps
        preds = defaultdict(lambda: [])
        for batch in batchify(
            tqdm(inputs, "Scoring components"), batch_size=self.batch_size
        ):
            self.reset_cache()
            for name, component in self.components.items():
                if component.scorer is not None:
                    scored_components[name] = component
                    batch = component.batch_process(batch)
                    preds[name].extend(copy.deepcopy(batch))

        t0 = time.time()
        for _ in tqdm(self.pipe(inputs), "Scoring pipeline", total=len(inputs)):
            pass
        duration = time.time() - t0

        # Scoring
        metrics: Dict[str, Any] = {
            "speed": len(inputs) / duration,
        }
        for name, component in scored_components.items():
            metrics[name] = component.score(list(zip(preds[name], golds)))
        return metrics

    def preprocess(self, doc: InputT, supervision: bool = False):
        """
        Runs the preprocessing methods of each component in the pipeline
        on a document and returns a dictionary containing the results, with the
        component names as keys.

        Parameters
        ----------
        doc: InputT
            The document to preprocess
        supervision: bool
            Whether to include supervision information in the preprocessing

        Returns
        -------
        Dict[str, Any]
        """
        prep = {}
        for name, component in self.components.items():
            if isinstance(component, TrainableComponent):
                prep[name] = component.preprocess(doc, supervision=supervision)
        return prep

    def preprocess_many(self, docs: Iterable[InputT], compress=True, supervision=True):
        """
        Runs the preprocessing methods of each component in the pipeline on
        a collection of documents and returns an iterable of dictionaries containing
        the results, with the component names as keys.

        Parameters
        ----------
        docs: Iterable[InputT]
        compress: bool
            Whether to deduplicate identical preprocessing outputs when multiple
            documents share identical subcomponents. This step is required to
            enable the cache mechanism when training or running the pipeline over
            tabular datasets such as pyarrow tables, which do not store
            referential equality information.
        supervision: bool
            Whether to include supervision information in the preprocessing

        Returns
        -------
        Iterable[OutputT]
        """
        preprocessed = map(partial(self.preprocess, supervision=supervision), docs)
        if compress:
            return batch_compress_dict(preprocessed)
        return preprocessed

    def collate(self, batch: Dict[str, Any], device: Optional[torch.device] = None):
        """
        Collates a batch of preprocessed samples into a single (maybe nested)
        dictionary of tensors by calling the collate method of each component.

        Parameters
        ----------
        batch: Dict[str, Any]
            The batch of preprocessed samples
        device: Optional[torch.device]
            The device to move the tensors to before returning them

        Returns
        -------
        Dict[str, Any]
            The collated batch
        """
        batch = decompress_dict(batch)
        if device is None:
            device = next(p.device for p in self.parameters())
        for name, component in self.components.items():
            if name in batch:
                component: TrainableComponent
                component_inputs = batch[name]
                batch[name] = component.collate(component_inputs, device)
        return batch

    def train(self, mode=True):
        """
        Enables or disables training mode on PyTorch modules

        Parameters
        ----------
        mode: bool
            Whether to enable training or not
        """
        for component in self.components.values():
            if hasattr(component, "train"):
                component.train(mode)

    @property
    def cfg(self):
        """Returns the initial configuration of the pipeline"""
        return Config(
            components=list(self.components.keys()),
            components_config=Config(**self.components, __path__=("components",)),
        ).serialize()

    @contextmanager
    def no_cache(self):
        """Disable caching for all (trainable) components in the pipeline"""
        saved = []
        for component in self.components.values():
            if isinstance(component, TrainableComponent):
                saved.append((component, component.enable_cache(False)))
        yield
        for component, do_cache in saved:
            component.enable_cache(do_cache)

    def parameters(self):
        """Returns an iterator over the Pytorch parameters of the components in the
        pipeline"""
        seen = set()
        for component in self.components.values():
            if isinstance(component, torch.nn.Module):
                for param in component.parameters():
                    if param in seen:
                        continue
                    seen.add(param)
                    yield param

    def __repr__(self):
        return "Pipeline({})".format(
            "\n{}\n".format(
                "\n".join(
                    indent(f"({name}): " + repr(component), prefix="  ")
                    for name, component in self.components.items()
                )
            )
            if len(self.components)
            else ""
        )

    @property
    def trainable_components(self) -> List[TrainableComponent]:
        """Returns the list of trainable components in the pipeline."""
        return [
            c
            for c in self.components.values()
            if isinstance(c, TrainableComponent) and c.needs_training
        ]

    def __iter__(self):
        """Returns an iterator over the components in the pipeline."""
        return iter(self.components.values())

    def __len__(self):
        """Returns the number of components in the pipeline."""
        return len(self.components)

cfg property

Returns the initial configuration of the pipeline

trainable_components: List[TrainableComponent] property

Returns the list of trainable components in the pipeline.

__init__(components=None, components_config=None, batch_size=4)

Initializes the pipeline. The pipeline is empty by default and can be populated with components via the add_pipe method.

PARAMETER DESCRIPTION
components

List of component names

TYPE: Optional[List[str]] DEFAULT: None

components_config

Dictionary of component configurations. The keys of the dictionary must match the component names. The values are the component configurations, which can contain unresolved configuration for nested components or instances of components.

TYPE: Optional[Dict[str, Component]] DEFAULT: None

batch_size

The default number of documents to process in parallel when running the pipeline

TYPE: int DEFAULT: 4
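
For illustration, a pipeline can also be built in one go from a component list and a matching configuration dictionary. A sketch, assuming both components are instantiated from the registry (constructor arguments are omitted here and depend on each component):

extractor = registry.factory.get("pdfminer-extractor")()
classifier = registry.factory.get("mask-classifier")()

pipeline = Pipeline(
    components=["extractor", "classifier"],
    components_config={"extractor": extractor, "classifier": classifier},
    batch_size=8,
)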

add_pipe(factory_name, name=None, config=None)

Adds a component to the pipeline. The component can be either a factory name or an instantiated component. If a factory name is provided, the component will be instantiated with the class from the registry matching the factory name and using the provided config as arguments.

PARAMETER DESCRIPTION
factory_name

Either a factory name or an instantiated component

TYPE: Union[str, Callable]

name

Name of the component

DEFAULT: None

config

Configuration of the component. The configuration can contain unresolved configuration for nested components such as {"@factory_name": "my-sub-component", ...}

DEFAULT: None

RETURNS DESCRIPTION
Component

The added component
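
For example (a sketch: the component name is arbitrary, the config keys depend entirely on the component's constructor, and SimpleAggregator is a hypothetical component class):

# Add a component by factory name, overriding its registered name.
classifier = pipeline.add_pipe("mask-classifier", name="classifier", config={})

# An already-instantiated component can be passed directly;
# providing a config alongside it raises a TypeError.
aggregator = pipeline.add_pipe(SimpleAggregator())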

reset_cache(cache=None)

Reset the caches of the components in this pipeline

PARAMETER DESCRIPTION
cache

The cache to reset (either preprocess, collate or forward). If None, all caches are reset

TYPE: Optional[CacheEnum] DEFAULT: None

__call__(doc)

Applies the pipeline to a sample

PARAMETER DESCRIPTION
doc

The document to process

TYPE: InputT

RETURNS DESCRIPTION
OutputT

pipe(docs)

Applies the pipeline to a collection of documents

PARAMETER DESCRIPTION
docs

The documents to process

TYPE: Iterable[InputT]

RETURNS DESCRIPTION
Iterable

An iterable collection of processed documents
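
A sketch of batch processing, assuming a few hypothetical PDF files; documents are processed batch_size at a time and results are yielded lazily:

pdfs = [open(path, "rb").read() for path in ["a.pdf", "b.pdf", "c.pdf"]]
for processed in pipeline.pipe(pdfs):
    print(processed)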

initialize(data)

Initialize the components of the pipeline. Each component must be initialized before the components that follow it are run. Since a component might need the full training data to be initialized, all of the data may be fed to it, which makes batch caching impossible.

Therefore, caching is disabled during the entire operation, so heavy computations (such as embeddings) that are usually shared will be repeated for each initialized component.

PARAMETER DESCRIPTION
data

TYPE: Iterable[InputT]
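
For instance (a sketch; train_docs is a hypothetical iterable of training documents):

# Components are initialized in order, with caching disabled throughout.
pipeline.initialize(train_docs)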

score(docs)

Scores the pipeline against a sequence of annotated documents

PARAMETER DESCRIPTION
docs

The documents to score

TYPE: Sequence[InputT]

RETURNS DESCRIPTION
Dict[str, Any]

A dictionary containing the metrics of the pipeline, as well as its speed. Each component that has a scorer is also scored, and its metrics are included in the returned dictionary under a key named after the component.
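
A sketch of a scoring run; annotated_docs is a hypothetical sequence of gold-annotated documents, and the "classifier" key assumes a component with that name defines a scorer:

metrics = pipeline.score(annotated_docs)
print(metrics["speed"])       # documents processed per second
print(metrics["classifier"])  # metrics of the scored component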

preprocess(doc, supervision=False)

Runs the preprocessing methods of each component in the pipeline on a document and returns a dictionary containing the results, with the component names as keys.

PARAMETER DESCRIPTION
doc

The document to preprocess

TYPE: InputT

supervision

Whether to include supervision information in the preprocessing

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Dict[str, Any]

preprocess_many(docs, compress=True, supervision=True)

Runs the preprocessing methods of each component in the pipeline on a collection of documents and returns an iterable of dictionaries containing the results, with the component names as keys.

PARAMETER DESCRIPTION
docs

TYPE: Iterable[InputT]

compress

Whether to deduplicate identical preprocessing outputs when multiple documents share identical subcomponents. This step is required to enable the cache mechanism when training or running the pipeline over tabular datasets such as pyarrow tables, which do not store referential equality information.

DEFAULT: True

supervision

Whether to include supervision information in the preprocessing

DEFAULT: True

RETURNS DESCRIPTION
Iterable[OutputT]

collate(batch, device=None)

Collates a batch of preprocessed samples into a single (maybe nested) dictionary of tensors by calling the collate method of each component.

PARAMETER DESCRIPTION
batch

The batch of preprocessed samples

TYPE: Dict[str, Any]

device

The device to move the tensors to before returning them

TYPE: Optional[torch.device] DEFAULT: None

RETURNS DESCRIPTION
Dict[str, Any]

The collated batch
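
A rough sketch of the training-time data path, combining preprocess_many and collate; train_docs is hypothetical, and the exact shape of each batch depends on batch_compress_dict:

import torch

for batch in pipeline.preprocess_many(train_docs, supervision=True):
    # Decompress and collate into (possibly nested) tensor dictionaries.
    tensors = pipeline.collate(batch, device=torch.device("cpu"))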

train(mode=True)

Enables or disables training mode on PyTorch modules

PARAMETER DESCRIPTION
mode

Whether to enable training or not

DEFAULT: True
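
For example, to switch between evaluation and training modes:

pipeline.train(False)  # evaluation mode, e.g. before scoring or inference
pipeline.train(True)   # back to training mode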

no_cache()

Disable caching for all (trainable) components in the pipeline
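
A sketch of usage as a context manager (doc is a hypothetical input document):

with pipeline.no_cache():
    result = pipeline(doc)  # runs with all trainable-component caches disabled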

parameters()

Returns an iterator over the PyTorch parameters of the components in the pipeline
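
The iterator plugs directly into a PyTorch optimizer, for example:

import torch

optimizer = torch.optim.Adam(pipeline.parameters(), lr=1e-3)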

__iter__()

Returns an iterator over the components in the pipeline.

__len__()

Returns the number of components in the pipeline.
