.. |br| raw:: html

   <br />

How to train your MHN
=====================
If you want to learn a new MHN from mutation data, the :code:`optimizers` submodule
is likely where you should start. It currently contains *Optimizer* classes for training
a *classical* MHN (cMHN) (see `Schill et al. (2020) `_)
or an *observation* MHN (oMHN) (see `Schill et al. (2024) `_).
For an extensive demonstration of a simple MHN training and analysis workflow, have a look at `this demo notebook `_.
Configure the Optimizer
-----------------------
You can learn a new MHN from cross-sectional data with the :code:`Optimizer` class:

.. code-block:: python

    from mhn.optimizers import Optimizer

    opt = Optimizer()

By default, this class will train the most recent type of MHN. To train an older type,
you can specify it explicitly:

.. code-block:: python

    # Example: training a classical MHN (cMHN) that does not account for the collider bias
    opt = Optimizer(Optimizer.MHNType.cMHN)

Next, specify the data that the MHN should be trained on:

.. code-block:: python

    opt.load_data_matrix(data_matrix)

Here, :code:`data_matrix` can be either a *numpy* matrix or a *pandas* DataFrame in which rows represent samples and
columns represent events.
A *numpy* matrix should have :code:`dtype=np.int32`; otherwise you might get
a warning. |br|
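
As a minimal sketch of the layout described above, the following snippet builds a small binary matrix
with samples as rows and events as columns, once as a *numpy* array with :code:`dtype=np.int32` and once as a
*pandas* DataFrame. The values and event names are made up for illustration.

.. code-block:: python

    import numpy as np
    import pandas as pd

    # Hypothetical toy data: 4 samples (rows) x 3 events (columns).
    # 1 means the event was observed in the sample, 0 means it was not.
    data_matrix = np.array(
        [
            [1, 0, 0],
            [1, 1, 0],
            [0, 0, 1],
            [1, 1, 1],
        ],
        dtype=np.int32,
    )

    # The same data as a DataFrame; the column names are invented event names.
    data_frame = pd.DataFrame(data_matrix, columns=["EventA", "EventB", "EventC"])

    print(data_matrix.dtype, data_matrix.shape)
    print(data_frame)
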
Alternatively, if your training data is stored in a CSV file, you can call

.. code-block:: python

    opt.load_data_from_csv(filename, delimiter)

where :code:`delimiter` is the delimiter separating the items in the CSV file (default: :code:`,`).
Internally, this method uses *pandas*' :code:`read_csv()` function to extract the data from the CSV file.
All additional keyword arguments given to this method will be passed on to that *pandas* function (see `read_csv() `_).
This means parameters like :code:`usecols` or :code:`skiprows` of the :code:`read_csv()` function
can also be used as parameters for this method:

.. code-block:: python

    # loads data from a CSV file, skipping lines 0 and 10 of the file
    # (pandas counts lines from 0, so line 0 is the first line of the file)
    opt.load_data_from_csv(filename, delimiter, skiprows=[0, 10])

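
Because the extra keyword arguments are forwarded to *pandas*, it can help to try them on :code:`read_csv()`
directly first. The sketch below only demonstrates the *pandas* behaviour of :code:`skiprows` on an in-memory
CSV; the CSV content is invented. Note that *pandas* counts lines from 0, so index 0 refers to the
header line if the file has one.

.. code-block:: python

    import io

    import pandas as pd

    # Hypothetical CSV content: line 0 is the header, lines 1-3 are samples.
    csv_text = "EventA,EventB,EventC\n1,0,0\n0,1,0\n1,1,1\n"

    # skiprows=[2] drops the line with index 2 of the file ("0,1,0").
    df = pd.read_csv(io.StringIO(csv_text), delimiter=",", skiprows=[2])

    print(df)
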
You can access the loaded data matrix with

.. code-block:: python

    loaded_matrix = opt.training_data

If you work with a CUDA-capable device, you can choose which device you want to use to train a new MHN:

.. code-block:: python

    # uses both CPU and GPU depending on the number of mutations in the individual sample (default)
    opt.set_device(Optimizer.Device.AUTO)

    # use the CPU to compute log-likelihood and gradient
    opt.set_device(Optimizer.Device.CPU)

    # use the GPU to compute log-likelihood and gradient
    opt.set_device(Optimizer.Device.GPU)

    # you can also access the Device enum directly with an Optimizer object
    opt.set_device(opt.Device.AUTO)

You can also change the initial theta that serves as the starting point for training. By default, this is
the independence model used by Schill et al. (2019). To set a different starting point, call

.. code-block:: python

    opt.set_init_theta(init_theta)

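
For orientation, the sketch below builds an independence-model-style starting point by hand. It assumes
the usual cMHN formulation, in which the diagonal holds the log-odds of each event's observed frequency and
all off-diagonal entries are zero. The helper :code:`independence_theta` is hypothetical and not part of the
package, and the shape that :code:`set_init_theta()` expects (in particular for an oMHN) should be checked
against the API documentation.

.. code-block:: python

    import numpy as np


    def independence_theta(data: np.ndarray) -> np.ndarray:
        """Hypothetical helper: diagonal log-odds matrix for binary data."""
        # Fraction of samples in which each event (column) occurs.
        freq = data.mean(axis=0)
        # Assumes every event occurs in some, but not all, samples;
        # otherwise the log-odds would be infinite.
        log_odds = np.log(freq / (1.0 - freq))
        # Off-diagonal entries are zero: events do not influence each other.
        return np.diag(log_odds)


    data = np.array([[1, 0], [1, 1], [0, 0], [1, 0]], dtype=np.int32)
    init_theta = independence_theta(data)
    print(init_theta)
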
If you want to regularly save the progress during training, you can use the :code:`save_progress()` method:

.. code-block:: python

    # in this example we want to make a backup every 100 iterations
    steps = 100
    # we want to overwrite the previous backup file
    always_new_file = False
    # we want our backup file to be named 'mhn_training_backup.npy'
    filename = 'mhn_training_backup.npy'

    opt.save_progress(steps=steps, always_new_file=always_new_file, filename=filename)

You can also specify a callback function that is called after each training step:

.. code-block:: python

    import numpy as np

    # In this example, we create a callback function that prints
    # the current theta matrix after each training step.
    # Ensure that your callback function accepts the theta matrix as a parameter;
    # otherwise, it will raise an error.
    def our_callback_function(theta: np.ndarray):
        print(theta)

    opt.set_callback_func(our_callback_function)

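
A callback does not have to print anything. As a minimal sketch, the snippet below records a copy of theta
at every call so that the trajectory can be inspected afterwards. The factory :code:`make_recording_callback`
is a hypothetical name, and the loop only stands in for the training steps so that the example runs on its own.

.. code-block:: python

    import numpy as np


    def make_recording_callback(history: list):
        """Hypothetical factory: returns a callback that stores theta copies."""

        def callback(theta: np.ndarray) -> None:
            # Copy, in case the caller reuses the same array between steps.
            history.append(theta.copy())

        return callback


    history = []
    record_theta = make_recording_callback(history)

    # Stand-in for the training loop: call the callback by hand.
    for step in range(3):
        record_theta(np.full((2, 2), float(step)))

    print(len(history))

A callback built this way could then be registered with :code:`opt.set_callback_func(record_theta)`.
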
During training, a regularization penalty is applied to prevent overfitting. The
:code:`Optimizer` class currently supports three types: the L1-penalty (used by default), the L2-penalty, and
a custom symmetrical penalty that is further discussed in `Schill et al. (2024) `_. |br|
The following code snippet shows how to set a penalty:

.. code-block:: python

    # for the L1-penalty, we set
    opt.set_penalty(opt.Penalty.L1)

    # for the L2-penalty, we set
    opt.set_penalty(opt.Penalty.L2)

    # for the symmetrical penalty, we set
    opt.set_penalty(opt.Penalty.SYM_SPARSE)

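
To give an intuition for the first two penalty types, the sketch below computes textbook L1 and L2 penalties
for a small theta matrix. It assumes that only off-diagonal entries are penalized. The function names are
hypothetical, and the exact expressions used inside the package may differ (for example by a smoothing term),
so treat this as an illustration rather than the library's implementation.

.. code-block:: python

    import numpy as np


    def off_diagonal(theta: np.ndarray) -> np.ndarray:
        """Return the off-diagonal entries of a square matrix as a 1-D array."""
        mask = ~np.eye(theta.shape[0], dtype=bool)
        return theta[mask]


    def l1_penalty(theta: np.ndarray) -> float:
        # Sum of absolute values: encourages sparse off-diagonal entries.
        return float(np.abs(off_diagonal(theta)).sum())


    def l2_penalty(theta: np.ndarray) -> float:
        # Sum of squares: shrinks large entries more strongly.
        return float((off_diagonal(theta) ** 2).sum())


    theta = np.array([[-1.0, 0.5], [-2.0, 0.3]])
    print(l1_penalty(theta), l2_penalty(theta))
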
Train a new MHN model
---------------------
Once your optimizer is configured, you can call the :code:`lambda_from_cv()` method
to find the best penalty strength ("lambda") for training by doing cross-validation. |br|
The :code:`lambda_from_cv()` method takes either a sequence of lambdas to test or
the minimum, the maximum, and the number of lambda values to evaluate. In the latter case,
the method creates a range of candidate lambdas with logarithmic grid spacing,
e.g. :code:`(0.0001, 0.0010, 0.0100, 0.1000)` for :code:`lambda_min=0.0001`,
:code:`lambda_max=0.1` and :code:`steps=4`. |br|
The following example uses the latter option:

.. code-block:: python

    import mhn

    # use a seed to make the cross-validation results reproducible
    mhn.set_seed(0)

    cv_lambda = opt.lambda_from_cv(
        lambda_min=1e-4,        # the smallest lambda value evaluated
        lambda_max=1e-1,        # the largest lambda value evaluated
        steps=4,                # total number of lambda values evaluated
        nfolds=5,               # number of cross-validation folds
        show_progressbar=True   # show a progressbar during cross-validation
    )

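
The logarithmically spaced grid from the example values above (:code:`lambda_min=0.0001`,
:code:`lambda_max=0.1`, :code:`steps=4`) can be reproduced with plain *numpy*, which is handy for checking which
candidates a given setting will cover. This is only a sketch of the spacing; how :code:`lambda_from_cv()`
builds its grid internally is not shown here.

.. code-block:: python

    import numpy as np

    lambda_min, lambda_max, steps = 1e-4, 1e-1, 4

    # geomspace returns `steps` values that are evenly spaced on a log scale,
    # including both endpoints.
    lambdas = np.geomspace(lambda_min, lambda_max, num=steps)
    print(lambdas)
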
Finally, you can train a new MHN with

.. code-block:: python

    opt.train(
        lam=cv_lambda,      # the lambda value used for regularization
        maxit=5000,         # the maximum number of training iterations
        round_result=True,  # round the resulting theta matrix to two decimal places
    )

This method returns an :code:`MHN` object (see :ref:`here <*model*: A submodule containing the MHN classes>`), which contains the learned model
and provides additional methods for cancer progression analysis.
You can also access the learned model via the :code:`result` property:

.. code-block:: python

    learned_mhn = opt.result

The documentation of all available optimizer classes can be found :ref:`here `.