Source code for pyspectools.models.torch_models


from pathlib import Path
from abc import abstractmethod

import torch
from torch import nn
from torch.nn import functional as F


class GenericModel(nn.Module):

    def __init__(self):
        super().__init__()
    @classmethod
    def load_weights(cls, weights_path=None, device=None, **kwargs):
        """
        Convenience method for loading in the weights of a model.
        Initializes the model, then wraps `torch.load` with automatic
        CUDA/CPU detection.

        Parameters
        ----------
        weights_path : str, optional
            Path to the trained weights of a model, typically with
            extension .pt. If None (default), the weights are looked up
            in the package directory under the model's name.
        device : str, optional
            Target device, either "cpu", "cuda", or a specific CUDA
            device (e.g. "cuda:0"). If None (default), the model is
            loaded onto a GPU if one is available, otherwise a CPU.
        **kwargs
            Additional keyword arguments are passed to the model
            constructor, allowing you to set different parameters.

        Returns
        -------
        model
            Instance of the PyTorch model with loaded weights
        """
        # default location for weights is the package directory,
        # along with the model name
        if weights_path is None:
            pkg_dir = Path(__file__).parent
            weights_path = pkg_dir.joinpath(f"{cls.__name__}.pt")
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
        model = cls(**kwargs)
        model.device = device
        model.load_state_dict(torch.load(weights_path, map_location=device))
        return model
    def init_layers(self, weight_func=None, bias_func=None):
        """
        Initialize the weights and biases of the model layers. This
        function uses the `apply` method of `Module`, and so will only
        act on layers that are registered as children.

        Note that `weight_func` and `bias_func` are currently unused:
        initialization is delegated to `_initialize_wb`, which selects
        an initializer based on the layer type.

        Parameters
        ----------
        weight_func : `nn.init` function, optional
            Function to use to initialize weights, by default None
        bias_func : `nn.init` function, optional
            Function to use to initialize biases, by default None
        """
        # Apply initializers to all of the Module's children with `apply`
        self.apply(self._initialize_wb)
    def _initialize_wb(self, layer: nn.Module):
        """
        Helper for applying an initializer to the weights and biases
        of a layer. Layers without weight and bias attributes are
        effectively ignored.

        Parameters
        ----------
        layer : `nn.Module`
            Layer that is a subclass of `nn.Module`
        """
        if isinstance(layer, nn.Linear):
            torch.nn.init.kaiming_normal_(layer.weight.data)
            if layer.bias is not None:
                torch.nn.init.uniform_(layer.bias.data, a=-1., b=1.)
        elif isinstance(layer, nn.LSTM):
            torch.nn.init.xavier_uniform_(layer.weight_hh_l0)
            torch.nn.init.xavier_uniform_(layer.weight_ih_l0)
            if layer.bias_ih_l0 is not None:
                torch.nn.init.zeros_(layer.bias_ih_l0)
            if layer.bias_hh_l0 is not None:
                torch.nn.init.zeros_(layer.bias_hh_l0)
        if isinstance(layer, nn.BatchNorm1d):
            torch.nn.init.ones_(layer.weight.data)
            if layer.bias is not None:
                torch.nn.init.zeros_(layer.bias.data)

    def __len__(self):
        return self.get_num_parameters()
    def get_num_parameters(self) -> int:
        """
        Calculate the number of parameters contained within the model.

        Returns
        -------
        int
            Number of trainable parameters
        """
        return sum([p.numel() for p in self.parameters() if p.requires_grad])
    @abstractmethod
    def compute_loss(self):
        pass
    def _reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor):
        """
        Private method implementing the reparameterization trick: a
        standard normal sample N(0, 1) is scaled and shifted using the
        parameterized mu and logvar of a variational model. Returns the
        latent encoding based on these values.

        Parameters
        ----------
        mu : torch.Tensor
            Tensor of Gaussian centers.
        logvar : torch.Tensor
            Tensor of log variance

        Returns
        -------
        z
            Torch tensor corresponding to the latent embedding
        """
        std = logvar.exp().sqrt()
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu)
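
The helpers above are inherited by all of the concrete models defined below. The following is a rough usage sketch, not part of the original module: the checkpoint filename is hypothetical, and any of the concrete classes further down can be substituted.

# Load trained weights; load_weights picks CUDA if available, otherwise CPU.
model = VarMolDetect.load_weights("varmoldetect.pt")  # hypothetical .pt checkpoint
model.eval()                                          # inference mode: disable dropout
print(len(model))                                     # number of trainable parameters

# Re-initialize a freshly constructed model before training from scratch.
fresh = VarMolDetect()
fresh.init_layers()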

class VariationalSpecDecoder(GenericModel):
    """
    Uses variational inference to capture the uncertainty with respect
    to Coulomb matrix eigenvalues. Instead of using dropout, this model
    represents uncertainty via a probabilistic latent layer.
    """

    __name__ = "VariationalSpecDecoder"

    def __init__(
        self,
        latent_dim=14,
        output_dim=30,
        alpha=0.8,
        dropout=0.2,
        optimizer=None,
        loss_func=None,
        opt_settings=None,
        param_transform=None,
        tracker=True,
    ):
        super().__init__()
        self.mu_dense = nn.Linear(12, latent_dim)
        self.logvar_dense = nn.Linear(12, latent_dim)
        # output should all be positive
        self.spec_decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(128, 256),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(256, output_dim),
            nn.ReLU(inplace=True),
        )
        self.input_drop = nn.Dropout(dropout)
    def forward(self, X: torch.Tensor):
        """
        The input to this model is a single tensor, where each row is
        12 elements long (8 constants plus a one-hot encoding of the
        composition). The idea behind this is to predict the
        eigenspectrum conditional on the molecular composition.

        Parameters
        ----------
        X : torch.Tensor
            Tensor containing the spectroscopic constants and the
            one-hot encoding of the composition.

        Returns
        -------
        output, mu, logvar
            The predicted eigenspectrum, and the latent parameters
            mu and logvar
        """
        mu, logvar = self.mu_dense(X), -F.relu(self.logvar_dense(X))
        z = self._reparametrize(mu, logvar)
        output = self.spec_decoder(z)
        return output, mu, logvar
    def compute_loss(self, X: torch.Tensor, Y: torch.Tensor):
        """
        Calculate the loss of this model as the combined prediction
        error and the KL divergence from the approximate posterior.

        Parameters
        ----------
        X : torch.Tensor
            Combined tensor of the spectroscopic constants and the
            one-hot encoded composition.
        Y : torch.Tensor
            Target eigenspectrum

        Returns
        -------
        torch.Tensor
            Joint loss of MSE and KL divergence
        """
        pred_Y, mu, logvar = self.forward(X)
        accuracy = F.mse_loss(pred_Y, Y, reduction="sum")
        var = logvar.exp()
        # The summation is performed over the encoding length, following Kingma & Welling
        KL = -0.5 * torch.sum(1 + 2.0 * logvar - mu.pow(2.0) - var.pow(2.0))
        return accuracy + KL
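
As a minimal sketch of how this loss can drive a single optimization step, using random stand-in tensors with the default shapes (12 inputs per row: 8 spectroscopic constants plus a 4-element one-hot composition; a 30-point target eigenspectrum). Illustrative only, not part of the original module.

model = VariationalSpecDecoder()
X = torch.randn(16, 12)   # 8 constants + 4-element one-hot composition
Y = torch.rand(16, 30)    # target eigenspectrum (default output_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
optimizer.zero_grad()
loss = model.compute_loss(X, Y)
loss.backward()
optimizer.step()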

class VariationalDecoder(GenericModel):
    """
    This model uses the intermediate eigenspectrum to calculate a latent
    embedding that is then used to predict the molecular formula and
    functional groups. You can think of the first action as
    "re-encoding", but the driving principle is that an eigenspectrum
    could map onto various structures, even when conditional on the
    composition.
    """

    __name__ = "VariationalDecoder"

    def __init__(
        self,
        latent_dim=14,
        eigen_length=30,
        nclasses=23,
        alpha=0.8,
        dropout=0.2,
        loss_func=None,
        param_transform=None,
        tracker=True,
    ):
        super().__init__()
        self.mu_dense = nn.Linear(eigen_length + 4, latent_dim)
        self.logvar_dense = nn.Linear(eigen_length + 4, latent_dim)
        self.formula_decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(128, 64),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(64, 4),
            nn.ReLU(inplace=True),
        )
        self.functional_classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(128, 256),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(256, 64),
            nn.Dropout(dropout),
            nn.LeakyReLU(alpha),
            nn.Linear(64, nclasses),
            nn.LogSigmoid(),
        )
    def forward(self, X: torch.Tensor):
        """
        Perform a forward pass of the VariationalDecoder model. This
        takes the concatenated input of the eigenspectrum and the
        one-hot composition, and produces a latent embedding that is
        then used to predict the formula and the functional group
        classification.

        Parameters
        ----------
        X : torch.Tensor
            Concatenated tensor of the eigenspectrum and the one-hot
            encoded composition.

        Returns
        -------
        formula_output
            Nx4 tensor corresponding to the number of atoms in the
            [H,C,O,N] positions.
        functional_output
            Nx23 tensor corresponding to multilabel classification,
            provided as log sigmoid.
        mu, logvar
            Latent variables of the variational layer
        """
        mu, logvar = self.mu_dense(X), -F.relu(self.logvar_dense(X))
        # generate latent representation
        z = self._reparametrize(mu, logvar)
        formula_output = self.formula_decoder(z)
        functional_output = self.functional_classifier(z)
        return formula_output, functional_output, mu, logvar
    def compute_loss(
        self, X: torch.Tensor, formula: torch.Tensor, groups: torch.Tensor
    ):
        """
        Calculate the joint loss of this model. This corresponds to the
        sum of three components: a KL-divergence loss for the
        variational layer, the formula prediction accuracy as the MSE
        loss, and the BCE loss for the multilabel classification of the
        functional group prediction.

        Parameters
        ----------
        X : torch.Tensor
            Concatenated tensor of the eigenspectrum and the one-hot
            encoded composition.
        formula : torch.Tensor
            Target formula encoding, typically of length 4 [H,C,O,N]
        groups : torch.Tensor
            Target functional group encoding.

        Returns
        -------
        torch.Tensor
            Joint loss of MSE, BCE, and KL divergence
        """
        pred_formula, pred_func, mu, logvar = self.forward(X)
        # Predict atom number
        accuracy = F.mse_loss(pred_formula, formula, reduction="sum")
        # Multilabel classification
        classification = F.binary_cross_entropy_with_logits(
            pred_func, groups, reduction="sum"
        )
        # calculate the divergence term
        var = logvar.exp()
        # The summation is performed over the encoding length, following Kingma & Welling
        KL = -0.5 * torch.sum(1 + 2.0 * logvar - mu.pow(2.0) - var.pow(2.0))
        return accuracy + KL + classification
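
A short shape sketch for this decoder, with random stand-ins and default hyperparameters: the input is the 30-point eigenspectrum concatenated with the 4-element one-hot composition, i.e. eigen_length + 4 = 34 features per row. Illustrative only, not part of the original module.

model = VariationalDecoder()
eigen = torch.rand(16, 30)                              # eigenspectrum
composition = torch.eye(4)[torch.randint(0, 4, (16,))]  # one-hot composition
X = torch.cat((eigen, composition), dim=-1)             # shape (16, 34)
formula = torch.randint(0, 6, (16, 4)).float()          # target [H, C, O, N] counts
groups = torch.randint(0, 2, (16, 23)).float()          # multilabel functional groups

pred_formula, pred_groups, mu, logvar = model(X)        # forward pass
loss = model.compute_loss(X, formula, groups)           # joint MSE + BCE + KL loss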

class VarMolDetect(GenericModel):
    """
    Umbrella model that encapsulates the full set of variational models.
    The premise is to more or less try to do end-to-end learning, and
    should meet the user half-way in terms of usability.

    The `forward` method takes the spectroscopic constants and the
    molecular composition as separate inputs, and performs the
    concatenation prior to any calculation. The composition is reused
    by the `VariationalDecoder` model.
    """

    __name__ = "VariationalMoleculeDetective"

    def __init__(
        self,
        eigen_length=30,
        latent_dim=14,
        nclasses=23,
        alpha=0.8,
        dropout=0.2,
        tracker=True,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(8)
        self.spectrum_decoder = VariationalSpecDecoder(
            latent_dim, eigen_length, alpha, dropout=dropout
        )
        self.joint_decoder = VariationalDecoder(
            latent_dim, eigen_length, nclasses, alpha, dropout=dropout
        )
        self.input_dropout = nn.Dropout(dropout)
        for name, parameter in self.named_parameters():
            if "logvar" in name and "weight" in name:
                # nn.init.uniform_(parameter, a=-10., b=-8.)
                nn.init.kaiming_normal_(parameter)
            elif "bias" in name:
                nn.init.zeros_(parameter)
            elif "weight" in name and "norm" not in name:
                nn.init.kaiming_normal_(parameter)
            elif "weight" in name and "norm" in name:
                nn.init.ones_(parameter)
    def forward(self, constants: torch.Tensor, composition: torch.Tensor):
        # This mask ensures that predictions of formulae are appropriate
        # for the specified composition. The conditional estimation alone
        # was okay, but could sometimes still predict formulae it shouldn't
        comp_mask = torch.FloatTensor(
            [
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 0, 1],
                [1, 1, 1, 1],
            ]
        ).to(constants.device)
        comp_encoding = composition.argmax(dim=-1)
        masks = comp_mask[comp_encoding].requires_grad_(True)
        # Run batch norm on the spectroscopic constants
        constants = self.input_dropout(self.norm(constants))
        # concatenate the inputs
        X = torch.cat((constants, composition), dim=-1)
        # compute the eigenspectrum conditional on composition
        eigen, eigen_mu, eigen_logvar = self.spectrum_decoder(X)
        # run the decoders to predict properties, conditional on the composition
        eigen_composition = torch.cat((eigen, composition), dim=-1)
        formula, functionals, decode_mu, decode_logvar = self.joint_decoder(
            eigen_composition
        )
        # Remove predictions of atoms that don't belong
        formula = formula * masks
        return (
            (eigen, formula, functionals),
            (eigen_mu, eigen_logvar, decode_mu, decode_logvar),
        )
    def compute_loss(self, constants, composition, eigenspectrum, formula, functionals):
        # run through the models
        predictions, latents = self.forward(constants, composition)
        # unpack the predictions, and compute their losses
        pred_eigen, pred_formula, pred_func = predictions
        # accumulate the regression and classification terms in place
        prediction_loss = F.mse_loss(pred_eigen, eigenspectrum, reduction="mean")
        prediction_loss.add_(
            F.mse_loss(pred_formula, formula, reduction="mean")
        )
        prediction_loss.add_(
            F.binary_cross_entropy_with_logits(pred_func, functionals, reduction="mean")
        )
        # now for the variational losses
        eigen_mu, eigen_logvar, decode_mu, decode_logvar = latents
        eigen_var = eigen_logvar.exp()
        decode_var = decode_logvar.exp()
        # The summation is performed over the encoding length, following Kingma & Welling
        kl_loss = -0.5 * torch.sum(
            1 + 2.0 * eigen_logvar - eigen_mu.pow(2.0) - eigen_var.pow(2.0)
        )
        kl_loss /= constants.size(0) * eigen_logvar.size(0)
        kl_loss.add_(
            -0.5 * torch.sum(
                1 + 2.0 * decode_logvar - decode_mu.pow(2.0) - decode_var.pow(2.0)
            )
        )
        kl_loss /= constants.size(0) * decode_logvar.size(0)
        return prediction_loss + kl_loss
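
Finally, a minimal end-to-end training sketch for the umbrella model, assuming the default hyperparameters and random stand-in data (8 spectroscopic constants, a 4-class one-hot composition, a 30-point eigenspectrum, [H,C,O,N] formula counts, and 23 binary functional-group labels). Illustrative only, not part of the original module.

batch = 32
constants = torch.randn(batch, 8)
composition = torch.eye(4)[torch.randint(0, 4, (batch,))]
eigenspectrum = torch.rand(batch, 30)
formula = torch.randint(0, 6, (batch, 4)).float()
functionals = torch.randint(0, 2, (batch, 23)).float()

model = VarMolDetect()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    optimizer.zero_grad()
    loss = model.compute_loss(
        constants, composition, eigenspectrum, formula, functionals
    )
    loss.backward()
    optimizer.step()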