Skip to content

ModelManager

automil.model.ModelManager provides model management utilities and is responsible for instantiating MIL models and validating hyperparameters against model-specific constraints and limits.

ModelManager

Manages the instantiation and configuration of MIL models.

This class provides
  • An interface for creating automil supported MIL models
  • Model-specific hyperparameter validation, with suggested adjustments when values fall outside recommended model limits
  • Dummy input generation for debugging and validation
Source code in automil/model.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
class ModelManager:
    """
    Manages the instantiation and configuration of MIL models.

    This class provides:
        - An interface for creating automil supported MIL models
        - Model-specific hyperparameter validation and adjustments should they be outside of recommended model-limits
        - dummy input generation for debugging and validation
    """
    # Baseline internal model configurations
    _MODEL_CONFIGS: dict[ModelType, ModelConfig] = {
        ModelType.Attention_MIL: ModelConfig(
            model_cls=Attention_MIL,
            slideflow_model_name="attention_mil",
            input_params={"input_dim": "n_feats", "num_classes": "n_out"},
            min_lr=1e-5,
            max_lr=1e-4,
            max_batch_size=MAX_BATCH_SIZE,
            max_tiles_per_bag=1000, # Can be quite large for Attention_MIL
        ),
        # We use smaller batch sizes for TransMIL due to memory constraints
        ModelType.TransMIL: ModelConfig(
            model_cls=TransMIL,
            slideflow_model_name="transmil",
            input_params={"input_dim": "n_feats", "num_classes": "n_out"},
            min_lr=1e-5,
            max_lr=1e-4,
            max_batch_size=32,
            max_tiles_per_bag=500,
        ),
        ModelType.BistroTransformer: ModelConfig(
            model_cls=BistroTransformer,
            slideflow_model_name="bistro_transformer",
            input_params={"input_dim": "dim", "num_classes": "heads"},
            min_lr=1e-5,
            max_lr=1e-4,
            max_batch_size=MAX_BATCH_SIZE,
            max_tiles_per_bag=1000,
        )
    }

    def __init__(self, model_type: ModelType) -> None:
        f"""Instantiates a ModelManager object

        Args:
            model_type (ModelType): Type of model to instantiate. Can be one of: {
                [model.name for model in ModelType]
            }
        """
        self.model_type = model_type
        self.config = self._MODEL_CONFIGS[model_type]

    @property
    def slideflow_name(self) -> str:
        """
        Slideflow-internal identifier for the managed model.

        Returns:
            str:
                Slideflow model name.
        """
        return self.config.slideflow_model_name

    @property
    def model_class(self) -> type[nn.Module]:
        """
        Corresponding python class implementing the model.

        Returns:
            type[nn.Module]:
                Model class.
        """
        return self.config.model_cls

    def create_model(self, input_dim: int = 1024, num_classes: int = 2, **kwargs) -> nn.Module:
        """Instantiates the model with validated hyperparameters.

        Args:
            input_dim (int, optional): Feature dimensions. Defaults to 1024.
            num_classes (int, optional): Number of classes. Defaults to 2.

        Returns:
            nn.Module: Instantiated model
        """
        # Map standardized parameter names to model-specific names
        model_params = {
            self.config.input_params["input_dim"]: input_dim,
            self.config.input_params["num_classes"]: num_classes,
        }
        # Update with remaining kwargs
        model_params.update(kwargs)

        # Fallback: If instantiation fails, try with defaults
        try:
            return self.model_class(**model_params)
        except TypeError as e:
            return self.model_class()

    def create_dummy_input(
        self, 
        batch_size: int, 
        tiles_per_bag: int, 
        input_dim: int
    ) -> tuple:
        """Creates an appropriate dummy input for the model

        Dummy input tensors can be used for a variety of tasks. Primarily they are used
        to perform `dry runs`, for example to measure the memory reservation of a model instance

        Args:
            batch_size: Number of samples in batch
            tiles_per_bag: Number of tiles per bag
            input_dim: Feature dimension

        Returns:
            Tuple of tensors to pass to model forward()
        """
        match self.model_type:

            case ModelType.Attention_MIL:
                # Both expect a lens tensor in addition to input
                dummy_input = torch.randn(batch_size, tiles_per_bag, input_dim).cuda()
                lens = torch.tensor([tiles_per_bag] * batch_size).cuda()
                return (dummy_input, lens)

            case ModelType.TransMIL | ModelType.BistroTransformer:
                # BistroTransformer expects only input (no lens)
                dummy_input = torch.randn(batch_size, tiles_per_bag, input_dim).cuda()
                return (dummy_input,)


    def validate_hyperparameters(self, lr: float, batch_size: int, max_tiles_per_bag: int) -> dict[str, float | int]:
        """
        Validates a set of hyperparameters against model-specific constraints.

        Args:
            lr (float):
                Learning rate.
            batch_size (int):
                Batch size.
            max_tiles_per_bag (int):
                Maximum tiles per bag.

        Returns:
            dict[str, float | int]:
                Suggested parameter adjustments for out-of-range values.
        """
        suggestions = {}

        # TODO | Better tuning logic / strategy (probably for all but definitely for lr)
        if not (self.config.min_lr <= lr <= self.config.max_lr):
            suggestions["lr"] = (self.config.min_lr + self.config.max_lr) / 2

        if batch_size > self.config.max_batch_size:
            suggestions["batch_size"] = self.config.max_batch_size

        if max_tiles_per_bag > self.config.max_tiles_per_bag:
            suggestions["max_tiles_per_bag"] = self.config.max_tiles_per_bag

        return suggestions

    @classmethod
    def compare_models(cls) -> str:
        """Generates a comparison table for all available models

        Returns:
            str: A comparison table as string
        """
        from tabulate import tabulate

        table = []
        for model_type, config in cls._MODEL_CONFIGS.items():
            table.append([
                model_type.name,
                config.slideflow_model_name,
                config.max_batch_size,
                config.max_tiles_per_bag,
                f"{config.min_lr:.0e}-{config.max_lr:.0e}",
            ])

        headers = [
            "Model Type", "Slideflow Name", "Max Batch Size", 
            "Max Tiles Per Bag", "LR Range"
        ]

        return tabulate(table, headers=headers, tablefmt="fancy_outline")

model_class property

model_class: type[Module]

Corresponding python class implementing the model.

Returns:

Type Description
type[Module]

type[nn.Module]: Model class.

slideflow_name property

slideflow_name: str

Slideflow-internal identifier for the managed model.

Returns:

Name Type Description
str str

Slideflow model name.

compare_models classmethod

compare_models() -> str

Generates a comparison table for all available models

Returns:

Name Type Description
str str

A comparison table as string

Source code in automil/model.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
@classmethod
def compare_models(cls) -> str:
    """Generates a comparison table for all available models

    Returns:
        str: A comparison table as string
    """
    from tabulate import tabulate

    headers = [
        "Model Type", "Slideflow Name", "Max Batch Size", 
        "Max Tiles Per Bag", "LR Range"
    ]

    # One row per registered model configuration.
    rows = [
        [
            mtype.name,
            cfg.slideflow_model_name,
            cfg.max_batch_size,
            cfg.max_tiles_per_bag,
            f"{cfg.min_lr:.0e}-{cfg.max_lr:.0e}",
        ]
        for mtype, cfg in cls._MODEL_CONFIGS.items()
    ]

    return tabulate(rows, headers=headers, tablefmt="fancy_outline")

create_dummy_input

create_dummy_input(
    batch_size: int, tiles_per_bag: int, input_dim: int
) -> tuple

Creates an appropriate dummy input for the model

Dummy input tensors can be used for a variety of tasks. Primarily they are used to perform dry runs, for example to measure the memory reservation of a model instance

Parameters:

Name Type Description Default
batch_size int

Number of samples in batch

required
tiles_per_bag int

Number of tiles per bag

required
input_dim int

Feature dimension

required

Returns:

Type Description
tuple

Tuple of tensors to pass to model forward()

Source code in automil/model.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def create_dummy_input(
    self, 
    batch_size: int, 
    tiles_per_bag: int, 
    input_dim: int
) -> tuple:
    """Creates an appropriate dummy input for the model

    Dummy input tensors can be used for a variety of tasks. Primarily they are used
    to perform `dry runs`, for example to measure the memory reservation of a model instance

    Args:
        batch_size: Number of samples in batch
        tiles_per_bag: Number of tiles per bag
        input_dim: Feature dimension

    Returns:
        Tuple of tensors to pass to model forward()
    """
    match self.model_type:

        case ModelType.Attention_MIL:
            # Both expect a lens tensor in addition to input
            dummy_input = torch.randn(batch_size, tiles_per_bag, input_dim).cuda()
            lens = torch.tensor([tiles_per_bag] * batch_size).cuda()
            return (dummy_input, lens)

        case ModelType.TransMIL | ModelType.BistroTransformer:
            # BistroTransformer expects only input (no lens)
            dummy_input = torch.randn(batch_size, tiles_per_bag, input_dim).cuda()
            return (dummy_input,)

create_model

create_model(
    input_dim: int = 1024, num_classes: int = 2, **kwargs
) -> nn.Module

Instantiates the model with validated hyperparameters.

Parameters:

Name Type Description Default
input_dim int

Feature dimensions. Defaults to 1024.

1024
num_classes int

Number of classes. Defaults to 2.

2

Returns:

Type Description
Module

nn.Module: Instantiated model

Source code in automil/model.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def create_model(self, input_dim: int = 1024, num_classes: int = 2, **kwargs) -> nn.Module:
    """Instantiates the model with validated hyperparameters.

    Args:
        input_dim (int, optional): Feature dimensions. Defaults to 1024.
        num_classes (int, optional): Number of classes. Defaults to 2.
        **kwargs: Additional keyword arguments forwarded to the model
            constructor.

    Returns:
        nn.Module: Instantiated model
    """
    # Map standardized parameter names to model-specific names
    model_params = {
        self.config.input_params["input_dim"]: input_dim,
        self.config.input_params["num_classes"]: num_classes,
    }
    # Update with remaining kwargs
    model_params.update(kwargs)

    try:
        return self.model_class(**model_params)
    except TypeError:
        # Fallback: the constructor rejected the mapped parameters.
        # Warn instead of failing silently so a misconfigured parameter
        # mapping is visible, then retry with the model's own defaults.
        import warnings
        warnings.warn(
            f"Could not instantiate {self.model_class.__name__} with "
            f"{model_params}; falling back to model defaults."
        )
        return self.model_class()

validate_hyperparameters

validate_hyperparameters(
    lr: float, batch_size: int, max_tiles_per_bag: int
) -> dict[str, float | int]

Validates a set of hyperparameters against model-specific constraints.

Parameters:

Name Type Description Default
lr float

Learning rate.

required
batch_size int

Batch size.

required
max_tiles_per_bag int

Maximum tiles per bag.

required

Returns:

Type Description
dict[str, float | int]

dict[str, float | int]: Suggested parameter adjustments for out-of-range values.

Source code in automil/model.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def validate_hyperparameters(self, lr: float, batch_size: int, max_tiles_per_bag: int) -> dict[str, float | int]:
    """
    Validates a set of hyperparameters against model-specific constraints.

    Args:
        lr (float):
            Learning rate.
        batch_size (int):
            Batch size.
        max_tiles_per_bag (int):
            Maximum tiles per bag.

    Returns:
        dict[str, float | int]:
            Suggested parameter adjustments for out-of-range values.
    """
    suggestions = {}

    # TODO | Better tuning logic / strategy (probably for all but definitely for lr)
    if not (self.config.min_lr <= lr <= self.config.max_lr):
        suggestions["lr"] = (self.config.min_lr + self.config.max_lr) / 2

    if batch_size > self.config.max_batch_size:
        suggestions["batch_size"] = self.config.max_batch_size

    if max_tiles_per_bag > self.config.max_tiles_per_bag:
        suggestions["max_tiles_per_bag"] = self.config.max_tiles_per_bag

    return suggestions

Helpers

create_model_instance

create_model_instance(
    model_type: ModelType, input_dim: int, n_out: int = 2
) -> nn.Module

Safely creates a model instance with the correct parameters.

This method instantiates a model corresponding to the provided ModelType with the specified input and output dimensions.

Parameters:

Name Type Description Default
model_type ModelType

The ModelType enum

required
input_dim int

Input feature dimension

required
n_out int

Number of output classes

2

Returns:

Type Description
Module

Instantiated model

Source code in automil/model.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def create_model_instance(
    model_type: ModelType,
    input_dim: int,
    n_out: int = 2
) -> nn.Module:
    """Safely creates a model instance with the correct parameters.

    This method instantiates a model corresponding to the provided
    :class:`ModelType` with the specified input and output dimensions

    Args:
        model_type: The ModelType enum
        input_dim: Input feature dimension
        n_out: Number of output classes

    Returns:
        Instantiated model
    """
    try:
        match model_type:

            case ModelType.Attention_MIL:
                model_cls = Attention_MIL
                return model_cls(n_feats=input_dim, n_out=n_out)

            case ModelType.TransMIL:
                model_cls = TransMIL
                return model_cls(n_feats=input_dim, n_out=n_out)

            case ModelType.BistroTransformer:
                model_cls = BistroTransformer
                return model_cls(dim=input_dim)

            case _:
                return model_cls()
    except Exception as e:
        slideflow_log.error(f"Error while creating model instance: {e}")
        raise e