fresco.training package

Submodules

fresco.training.training module

Module for training a deep learning model.

class fresco.training.training.ModelTrainer(kw_args, model, dw, class_weights=None, device=None, fold=None, clc=False)[source]

Bases: object

Training class definition.

savepath

Path for saving models and metrics.

Type:

str

epochs

Maximum number of epochs to train for.

Type:

int

patience_stop

Patience stopping criteria.

Type:

int

tasks

List of tasks; each task is a string.

Type:

list

n_task

Whether ntask (whole-document abstention) is used.

Type:

bool

model

Model definition, declared and initialized in the caller.

Type:

Model

dw

DataHandler class, initialized in the caller.

Type:

DataHandler

device

CUDA or CPU.

Type:

torch.device

best_loss

Best validation loss score seen so far.

Type:

float

patience_ctr

Patience counter.

Type:

int

loss

Loss value on device.

Type:

torch.tensor

y_preds

Dict of predictions (usually logits as torch.tensor), with tasks as keys.

Type:

dict

y_trues

Dict of ground-truth values (ints), with tasks as keys.

Type:

dict

multilabel

Whether this is a multilabel classification problem.

Type:

bool

abstain

Whether to use the deep abstaining classifier (DAC).

Type:

bool

mixed_precision

Whether to use PyTorch automatic mixed precision.

Type:

bool

opt

Optimizer for training.

Type:

torch.optim.Optimizer

reduction

Type of reduction for the loss function.

Type:

str

loss_funs

Dictionary of torch loss functions for training each task.

Type:

dict

class_weights

List of floats for class weighting schemes.

Type:

list

compute_clc_loss(logits, y, idxs, dac=None, val=False)[source]

Compute forward pass and case level loss function.

Parameters:
  • logits (torch.tensor) – Logits for each task.

  • y (dict) – Ground-truth labels, with tasks as keys.

  • idxs – Indices grouping documents into cases.

  • dac (Abstention) – Deep abstaining classifier class.

  • val (bool) – True if computing the validation loss.

Returns:

Float tensor.

Return type:

loss (torch.tensor)

Post-condition:

y_preds and y_trues are populated. loss is updated.

compute_clc_val_loss(batch, dac=None)[source]

Compute forward pass and loss function for case level context.

Parameters:
  • batch (torch.tensor) – Iterate from DataLoader.

  • dac (Abstention) – Deep abstaining classifier class.

Returns:

Float tensor.

Return type:

loss (torch.tensor)

Post-condition:

y_preds and y_trues are populated. loss is updated.

compute_loss(logits, y, dac=None)[source]

Compute forward pass and loss function.

Parameters:
  • logits (torch.tensor) – Logits for each task.

  • y (dict) – Ground-truth labels, with tasks as keys.

  • dac (Abstention) – Deep abstaining classifier class.

Returns:

Float tensor.

Return type:

loss (torch.tensor)

Post-condition:

y_preds and y_trues are populated. loss is updated.
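For intuition, the core of a cross-entropy loss over logits can be sketched in plain Python. The helper names below are illustrative, not part of fresco; in practice the class applies the torch loss functions stored in loss_funs:

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example: -log p(label | logits)."""
    # Subtract the max logit before exponentiating, for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    return -math.log(exps[label] / sum(exps))

def mean_loss(batch_logits, labels):
    """'mean' reduction over a batch, mirroring the `reduction` attribute."""
    losses = [cross_entropy(lg, y) for lg, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)
```

A uniform prediction over two classes yields a loss of log 2; a confident correct prediction yields a lower one.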

compute_scores(y_true, y_pred, task, dac=None)[source]

Compute macro/micro scores per task.

Parameters:
  • y_true (list) – List of ground truth labels (as int). It is a list of lists for n_tasks > 1.

  • y_pred (list) – List of predicted classes (as list of tensors). It is a list of lists for n_tasks > 1.

  • task (str) – Name of the task being scored.

  • dac (Abstention) – Deep abstaining classifier class.

Returns:

scores (dict): Dictionary of macro/micro metrics for the task.
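The exact metric set is defined in the source; as a reference point, micro- and macro-averaged F1 over integer labels can be sketched with the standard library (the function name is illustrative, not fresco's):

```python
from collections import Counter

def micro_macro_f1(y_true, y_pred):
    """Micro- and macro-averaged F1 over integer class labels."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1   # predicted p, but it was t
            fn[t] += 1   # missed the true class t
    def f1(tp_, fp_, fn_):
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0
    macro = sum(f1(tp[c], fp[c], fn[c]) for c in labels) / len(labels)
    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    return {"micro_f1": micro, "macro_f1": macro}
```

Macro averages per-class F1 equally; micro pools counts, so for single-label multiclass it equals accuracy.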

compute_val_loss(batch, dac=None)[source]

Compute forward pass and loss function.

Parameters:
  • batch (torch.tensor) – Iterate from DataLoader.

  • dac (Abstention) – Deep abstaining classifier class.

Returns:

Float tensor.

Return type:

loss (torch.tensor)

Post-condition:

y_preds and y_trues are populated. loss is updated.

fit_model(train_loader, val_loader=None, dac=None)[source]

Main training loop.

Parameters:
  • train_loader (torch.DataLoader) – Initialized and populated in the calling function.

  • val_loader (torch.DataLoader) – Initialized and populated in the calling function. May be None, in which case the training data is used for validation scores.

  • dac (Abstention) – Deep abstaining classifier class.
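The shape of the loop (train, validate, early-stop on patience, fall back to the training data when val_loader is None) can be sketched without torch. The toy "model" below is a single weight fit by gradient descent; all names are illustrative, not fresco's:

```python
def fit_model(train_data, val_data=None, epochs=50, patience_stop=5, lr=0.1):
    """Toy fit loop: minimize sum of (w - x)^2, early-stopping on val loss."""
    w = 0.0
    best_loss, patience_ctr = float("inf"), 0
    # When no validation set is given, score on the training data instead.
    val = val_data if val_data is not None else train_data
    for epoch in range(epochs):
        for x in train_data:                 # "training" pass over batches
            w -= lr * 2 * (w - x)            # gradient of (w - x)^2
        val_loss = sum((w - x) ** 2 for x in val) / len(val)
        if val_loss < best_loss:             # improvement: reset patience
            best_loss, patience_ctr = val_loss, 0
        else:                                # no improvement: count down
            patience_ctr += 1
            if patience_ctr >= patience_stop:
                break
    return w, best_loss
```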

get_ys(logits, y, idx, val=False)[source]

Get ground-truth values and predictions.

Parameters:
  • logits (list) – List of logits.

  • y (dict) – Dictionary of numpy ndarrays, with tasks as keys.

  • idx (int) – Index in the enumerated DataLoader.

  • val (bool) – True if computing validation predictions.

Post-condition:

self.y_trues and self.y_preds are populated.
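Its effect can be illustrated with a torch-free sketch: take the argmax over each task's logits and record the ground truth under the same task key (the name `get_ys_sketch` and the task names are illustrative):

```python
def get_ys_sketch(logits, y, tasks):
    """Build y_preds / y_trues dicts keyed by task from per-task logits."""
    y_preds, y_trues = {}, {}
    for i, task in enumerate(tasks):
        scores = logits[i]
        # Predicted class is the index of the largest logit for this task.
        y_preds[task] = max(range(len(scores)), key=scores.__getitem__)
        y_trues[task] = y[task]
    return y_preds, y_trues
```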

output_scores(scores, abs_scores=None, dac=None, stop_vals=None, alpha_scale=None, ntask_stop_val=None, ntask_alpha_scale=None)[source]

Print stats to the terminal.

Parameters:
  • scores (dict) – Dictionary of metrics, with tasks as keys and metrics as values.

  • abs_scores (dict) – Abstention-related scores from the dac.

  • dac (Abstention) – Deep abstaining classifier class.

  • stop_vals (list) – List of stopping criteria.

  • alpha_scale (dict) – Dictionary of scaling values for the dac, with tasks as keys.

  • ntask_stop_val (float) – Stopping criterion for ntask.

  • ntask_alpha_scale (float) – Scaling factor for ntask alpha.

print_stats(scores)[source]

Print macro/micro scores to stdout.
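A minimal formatter in that spirit, assuming a per-task dict of macro/micro values (this is a sketch, not fresco's actual output format):

```python
def format_scores(scores):
    """Render per-task macro/micro scores as one aligned line per task."""
    return "\n".join(
        f"{task:<12} macro {m['macro']:.4f}  micro {m['micro']:.4f}"
        for task, m in scores.items()
    )
```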

profile_fit_model(train_loader, dac=None)[source]

Main training loop, run with profiling enabled.

Parameters:
  • train_loader (torch.DataLoader) – Initialized and populated in the calling function.

  • dac (Abstention) – Deep abstaining classifier class.

score(epoch, val_loader=None, dac=None)[source]

Score a model during training.

Parameters:
  • epoch (int) – Epoch number.

  • val_loader (torch.DataLoader) – DataLoader class with PathReports.

  • dac (Abstention) – Deep abstaining classifier class; None when the DAC is not used.

stop_metrics(loss, epoch, stop_metrics=None, dac=None)[source]

Compute stop metrics.

Compute stop metrics for normal (val loss + patience) or DAC (stop metric) training.

Parameters:
  • loss (torch.tensor) – Float tensor, val loss value at the current epoch.

  • epoch (int) – Epoch counter.

  • stop_metrics (dict) – Dictionary of floats for abstention rates and accuracy.

Returns:

If abstaining:

stop_val (float): DAC stopping criterion.

stop (bool): Whether to stop training.

Otherwise:

stop_val (float): Best validation loss.

stop (bool): Whether to stop training.

Return type:

tuple

Post-condition: If not abstaining, the patience counter is updated.
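The non-abstaining branch described above (val loss plus patience) amounts to the following per-epoch check (an illustrative sketch, not the fresco implementation):

```python
def patience_step(loss, best_loss, patience_ctr, patience_stop):
    """Return (best_loss, patience_ctr, stop) after one epoch's val loss."""
    if loss < best_loss:
        best_loss, patience_ctr = loss, 0   # improvement: reset the counter
    else:
        patience_ctr += 1                   # no improvement: burn one patience unit
    return best_loss, patience_ctr, patience_ctr >= patience_stop
```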

train_metrics(dac=None)[source]

Compute per-epoch metrics during training.

fresco.training.training.trace_handler(p)[source]

Module contents