temet.ML package
Submodules
temet.ML.common module
Common ML/pytorch functionality.
- train_model(dataloader, model, loss_fn, optimizer, batch_size, epoch_num, writer=None, verbose=True)
Train the model for one epoch.
- test_model(dataloader, model, loss_fn, current_sample, acc_tol=None, writer=None, verbose=True)
Test the model and compute statistics.
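The `acc_tol` parameter of `test_model` is not documented here. One plausible reading, assumed for this sketch and not confirmed by temet, is a regression "accuracy": the fraction of predictions that fall within `acc_tol` of the label. The helper names below are hypothetical. The sketch reports that statistic alongside the mean squared error, which is the quantity `torch.nn.MSELoss` computes with its default `reduction='mean'`:

```python
def tolerance_accuracy(predictions, labels, acc_tol):
    """Fraction of predictions within acc_tol of the true label.

    Hypothetical helper: assumes acc_tol is an absolute tolerance on a
    regression target, which the temet docs do not confirm.
    """
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    hits = sum(1 for p, y in zip(predictions, labels) if abs(p - y) <= acc_tol)
    return hits / len(labels)


def mean_squared_error(predictions, labels):
    """Mean of (prediction - label)^2 over all samples."""
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


preds = [12.1, 11.7, 13.0, 12.5]
truth = [12.0, 12.0, 12.0, 12.4]
print(tolerance_accuracy(preds, truth, acc_tol=0.2))  # 0.5
print(mean_squared_error(preds, truth))
```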
temet.ML.explore module
Misc ML exploration.
- class mnist_network
Bases: Module
Simple NN to play with the MNIST Fashion dataset.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
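The note above is PyTorch's inherited `forward` docstring. The simplified, pure-Python model below is an illustration, not PyTorch's actual implementation. It shows why calling the instance differs from calling `forward()` directly: `__call__` runs the registered hooks around `forward`, and a direct `forward()` call skips them. In real PyTorch the corresponding registration method is `torch.nn.Module.register_forward_hook`, whose hooks receive `(module, args, output)`.

```python
class ToyModule:
    """Simplified stand-in for torch.nn.Module's call mechanism (illustration only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # hook(module, inputs, output) mirrors the argument order PyTorch uses
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError  # subclasses override this

    def __call__(self, x):
        output = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, (x,), output)
            if result is not None:  # a hook may replace the output
                output = result
        return output


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x


seen = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3)          # runs forward, then the hook: seen == [6]
m.forward(3)  # bypasses the hook: seen is still [6]
```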
- mnist_tutorial()
Playing with the MNIST Fashion dataset.
temet.ML.smhm module
Explorations: regression on stellar mass to halo mass (SMHM) relation.
- class SMHMDataset(simname, redshift, secondary_params=None)
Bases: Dataset
A custom dataset for the stellar mass to halo mass (SMHM) relation. Stores samples (M_star) and their corresponding labels (M_halo).
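PyTorch map-style datasets, which are subclasses of `torch.utils.data.Dataset`, implement `__getitem__` and usually `__len__`. This is what lets a `DataLoader` index and batch them. The dependency-free sketch below implements that protocol for (M_star, M_halo) pairs. The class name `ToySMHMDataset`, the helper `iterate_batches`, the log10 units and the numbers are all placeholders. They are not the temet implementation or simulation data.

```python
class ToySMHMDataset:
    """Map-style dataset: sample i is (log10 M_star, log10 M_halo). Hypothetical sketch."""

    def __init__(self, mstar, mhalo):
        if len(mstar) != len(mhalo):
            raise ValueError("mstar and mhalo must have equal length")
        self.samples = list(mstar)  # inputs (features)
        self.labels = list(mhalo)   # targets

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]


def iterate_batches(dataset, batch_size):
    """Yield (inputs, labels) lists in index order, like an unshuffled DataLoader."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        yield [x for x, _ in items], [y for _, y in items]


# placeholder values only
ds = ToySMHMDataset([9.5, 10.0, 10.5, 11.0, 11.5], [11.2, 11.6, 12.1, 12.8, 13.6])
batches = list(iterate_batches(ds, batch_size=2))  # the last batch holds the leftover sample
```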
- class mlp_network(hidden_size, num_inputs)
Bases: Module
Simple NN to play with the mstar->mhalo problem.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- train(hidden_size=8, verbose=True)
Train the SMHM MLP NN.
Explore the effect of the hidden layer size on the loss.
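Model capacity is one way to see what varying `hidden_size` does. The sketch below assumes a fully connected layout, `Linear(num_inputs, hidden_size) -> activation -> ... -> Linear(hidden_size, 1)`. That layout is an assumption, since the architecture is not documented here. Under it, the number of trainable weights and biases grows linearly with `hidden_size` for a single hidden layer. The helper name `mlp_param_count` is hypothetical.

```python
def mlp_param_count(num_inputs, hidden_size, num_hidden_layers=1, num_outputs=1):
    """Weights + biases of a fully connected MLP (assumed layout, see lead-in)."""
    if num_hidden_layers < 1:
        raise ValueError("need at least one hidden layer")
    count = num_inputs * hidden_size + hidden_size  # input -> first hidden layer
    count += (num_hidden_layers - 1) * (hidden_size * hidden_size + hidden_size)  # hidden -> hidden
    count += hidden_size * num_outputs + num_outputs  # last hidden -> output
    return count


for h in (2, 8, 32):
    # one scalar input (M_star): count = 3 * h + 1
    print(h, mlp_param_count(num_inputs=1, hidden_size=h))
```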
- plot_mstar_mhalo()
Plot the mstar->mhalo relation, ground truth vs trained model predictions.
- plot_mhalo_error_distribution()
Plot a histogram of (ground truth - trained model prediction), i.e., the error on mhalo.
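The sketch below shows the binning behind such a error histogram, with no dependencies. The function name `error_histogram` is hypothetical, and the plotting step itself (e.g. with matplotlib) is omitted. Every bin is half-open except the last, which also includes its right edge, matching the `numpy.histogram` convention. Residuals outside the edges are dropped.

```python
def error_histogram(true_vals, pred_vals, bin_edges):
    """Count residuals (true - predicted) per bin [edge_i, edge_{i+1})."""
    errors = [t - p for t, p in zip(true_vals, pred_vals)]
    counts = [0] * (len(bin_edges) - 1)
    for e in errors:
        for i in range(len(counts)):
            lo, hi = bin_edges[i], bin_edges[i + 1]
            is_last = i == len(counts) - 1
            # last bin is closed on the right, all others are half-open
            if lo <= e < hi or (is_last and e == hi):
                counts[i] += 1
                break
    return counts


counts = error_histogram(
    [12.0, 12.5, 13.0, 11.5],   # ground truth (placeholder values)
    [12.25, 12.5, 12.5, 11.0],  # predictions (placeholder values)
    bin_edges=[-0.5, 0.0, 0.5],
)
```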
- plot_true_vs_predicted_mhalo(hidden_size=8)
Scatterplot of true vs predicted labels, versus the one-to-one (perfect) relation.
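Points on the one-to-one line satisfy predicted = true. Two natural numbers to summarize such a scatterplot are therefore the mean offset from that line (bias) and the root-mean-square scatter about it. The sketch below computes both. The helper `one_to_one_stats` is hypothetical and not part of temet.

```python
import math


def one_to_one_stats(true_vals, pred_vals):
    """Return (bias, rms) of predicted - true, i.e. offset and scatter about y = x."""
    residuals = [p - t for t, p in zip(true_vals, pred_vals)]
    n = len(residuals)
    if n == 0:
        raise ValueError("need at least one point")
    bias = sum(residuals) / n
    rms = math.sqrt(sum(r * r for r in residuals) / n)
    return bias, rms


bias, rms = one_to_one_stats([11.0, 12.0, 13.0], [11.5, 12.0, 12.5])
```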
temet.ML.spectra module
Explorations: inference from mock spectra.
- class MockSpectraDataset(simname, redshift, ion, instrument, model_type, EW_minmax=None, SNR=None, num_noisy_per_sample=1, coldens=False)
Bases: Dataset
A custom dataset for loading mock spectra and corresponding labels.
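The `SNR` and `num_noisy_per_sample` parameters suggest that each clean spectrum can be expanded into several noisy copies, presumably each sharing the clean spectrum's label. One common convention, assumed here rather than taken from temet, is Gaussian per-pixel noise with standard deviation 1/SNR on the continuum-normalized flux. The helper name `noisy_realizations` is hypothetical.

```python
import random


def noisy_realizations(flux, snr, num_noisy_per_sample=1, seed=None):
    """Return num_noisy_per_sample noisy copies of a normalized spectrum.

    Assumption: per-pixel Gaussian noise with sigma = 1 / snr
    (flux normalized so that the continuum is 1).
    """
    if snr <= 0:
        raise ValueError("snr must be positive")
    rng = random.Random(seed)  # local RNG: reproducible for a given seed
    sigma = 1.0 / snr
    return [[f + rng.gauss(0.0, sigma) for f in flux] for _ in range(num_noisy_per_sample)]


clean = [1.0, 0.8, 0.5, 0.9, 1.0]
copies = noisy_realizations(clean, snr=20, num_noisy_per_sample=3, seed=1)
```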
- class mlp_network(hidden_size, num_inputs, num_hidden_layers=1)
Bases: Module
Simple MLP NN to explore the (normalized absorption spectra) -> (EW) mapping.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
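The target of this mapping is the equivalent width. For a continuum-normalized spectrum F(λ), EW = ∫ (1 − F(λ)) dλ. The sketch below evaluates that integral with the trapezoidal rule on a wavelength grid. The function name is hypothetical, and temet may compute its EW labels differently.

```python
def equivalent_width(wavelength, flux):
    """EW = integral of (1 - F) d(lambda) by the trapezoidal rule.

    Units follow the wavelength grid; flux is assumed continuum-normalized.
    """
    if len(wavelength) != len(flux) or len(wavelength) < 2:
        raise ValueError("need matching arrays with at least two points")
    ew = 0.0
    for i in range(len(wavelength) - 1):
        dlam = wavelength[i + 1] - wavelength[i]
        mean_depth = 0.5 * ((1.0 - flux[i]) + (1.0 - flux[i + 1]))
        ew += mean_depth * dlam
    return ew


# triangular absorption line, depth 1 at the center, on a toy grid
ew = equivalent_width([0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 0.5, 0.0, 0.5, 1.0])
```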
- class cnn_network(kernel_size, hidden_size, num_inputs, num_hidden_layers=1)
Bases: Module
Simple CNN to explore the (normalized absorption spectra) -> (EW) mapping.
- forward(x)
Forward pass.
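The `kernel_size` argument sets how many adjacent spectral pixels each convolutional filter sees. The pure-Python sketch below shows a single-channel 1D convolution computed the way `torch.nn.Conv1d` does it (a cross-correlation, with no kernel flip). It also gives the output-length formula from the `Conv1d` documentation. Padding, stride and the actual layer layout of `cnn_network` are assumptions here, and the helper names are hypothetical.

```python
def conv1d_output_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length formula documented for torch.nn.Conv1d."""
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv1d_valid(signal, kernel, bias=0.0):
    """Single-channel cross-correlation with stride 1 and no padding."""
    k = len(kernel)
    if k > len(signal):
        raise ValueError("kernel longer than signal")
    return [
        sum(signal[i + j] * kernel[j] for j in range(k)) + bias
        for i in range(len(signal) - k + 1)
    ]


spectrum = [1.0, 1.0, 0.5, 1.0, 1.0]            # toy normalized flux
edges = conv1d_valid(spectrum, [1.0, -1.0])     # responds to flux changes
```

With no padding, each layer shortens the spectrum by `kernel_size - 1` pixels. This affects the input width of any fully connected layer that follows the convolutions.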
- train(model_type='cnn', model_params=None, verbose=True)
Train the mockspec model.
- plot_true_vs_predicted(model_type='cnn', params=None)
Scatterplot of true vs predicted labels, versus the one-to-one (perfect) relation.
- run()
Driver.
Module contents
Machine learning explorations.