-
Notifications
You must be signed in to change notification settings - Fork 383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes to pyro model initialisation & sampling [WIP] #2695
Open
vitkl
wants to merge
8
commits into
scverse:main
Choose a base branch
from
vitkl:pyro_fixes
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a244ff3
fixes to pyro model init & sampling
vitkl 7127dc3
add missing tests
vitkl 53728cf
filter tensors by return_sites & exclude_vars
vitkl 368e3d5
detect valid sites to avoid missing to remove deterministic
vitkl 6af8857
bug fix
vitkl 77668a1
additional valid site filtering
vitkl 3bde45e
call both model and guide to create parameters in both
vitkl 051d7dd
Removing criticism
vitkl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
|
||
import numpy as np | ||
import torch | ||
from lightning.pytorch.callbacks import Callback | ||
from pyro import poutine | ||
|
||
from scvi import settings | ||
|
@@ -18,59 +17,17 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
class PyroJitGuideWarmup(Callback): | ||
"""A callback to warmup a Pyro guide. | ||
def setup_pyro_model(dataloader, pl_module): | ||
"""Way to warmup Pyro Model and Guide in an automated way. | ||
|
||
This helps initialize all the relevant parameters by running | ||
one minibatch through the Pyro model. | ||
Setup occurs before any device movement, so params are iniitalized on CPU. | ||
""" | ||
|
||
def __init__(self, dataloader: AnnDataLoader = None) -> None: | ||
super().__init__() | ||
self.dataloader = dataloader | ||
|
||
def on_train_start(self, trainer, pl_module): | ||
"""Way to warmup Pyro Guide in an automated way. | ||
|
||
Also device agnostic. | ||
""" | ||
# warmup guide for JIT | ||
pyro_guide = pl_module.module.guide | ||
if self.dataloader is None: | ||
dl = trainer.datamodule.train_dataloader() | ||
else: | ||
dl = self.dataloader | ||
for tensors in dl: | ||
tens = {k: t.to(pl_module.device) for k, t in tensors.items()} | ||
args, kwargs = pl_module.module._get_fn_args_from_batch(tens) | ||
pyro_guide(*args, **kwargs) | ||
break | ||
|
||
|
||
class PyroModelGuideWarmup(Callback): | ||
"""A callback to warmup a Pyro guide and model. | ||
|
||
This helps initialize all the relevant parameters by running | ||
one minibatch through the Pyro model. This warmup occurs on the CPU. | ||
""" | ||
|
||
def __init__(self, dataloader: AnnDataLoader) -> None: | ||
super().__init__() | ||
self.dataloader = dataloader | ||
|
||
def setup(self, trainer, pl_module, stage=None): | ||
"""Way to warmup Pyro Model and Guide in an automated way. | ||
|
||
Setup occurs before any device movement, so params are iniitalized on CPU. | ||
""" | ||
if stage == "fit": | ||
pyro_guide = pl_module.module.guide | ||
dl = self.dataloader | ||
for tensors in dl: | ||
tens = {k: t.to(pl_module.device) for k, t in tensors.items()} | ||
args, kwargs = pl_module.module._get_fn_args_from_batch(tens) | ||
pyro_guide(*args, **kwargs) | ||
break | ||
for tensors in dataloader: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to do next(iter(dataloader)) to get a single batch. I think still having the class makes sense. Within this class, there can be a manual_start function. |
||
tens = {k: t.to(pl_module.device) for k, t in tensors.items()} | ||
args, kwargs = pl_module.module._get_fn_args_from_batch(tens) | ||
pl_module.module.guide(*args, **kwargs) | ||
pl_module.module.model(*args, **kwargs) | ||
break | ||
|
||
|
||
class PyroSviTrainMixin: | ||
|
@@ -177,7 +134,14 @@ def train( | |
|
||
if "callbacks" not in trainer_kwargs.keys(): | ||
trainer_kwargs["callbacks"] = [] | ||
trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) | ||
|
||
# Initialise pyro model with data | ||
from copy import copy | ||
|
||
dl = copy(data_splitter) | ||
dl.setup() | ||
dl = dl.train_dataloader() | ||
setup_pyro_model(dl, training_plan) | ||
|
||
runner = self._train_runner_cls( | ||
self, | ||
|
@@ -202,8 +166,9 @@ def _get_one_posterior_sample( | |
self, | ||
args, | ||
kwargs, | ||
return_sites: list | None = None, | ||
return_sites: list = None, | ||
return_observed: bool = False, | ||
exclude_vars: list = None, | ||
): | ||
"""Get one sample from posterior distribution. | ||
|
||
|
@@ -225,6 +190,11 @@ def _get_one_posterior_sample( | |
if isinstance(self.module.guide, poutine.messenger.Messenger): | ||
# This already includes trace-replay behavior. | ||
sample = self.module.guide(*args, **kwargs) | ||
# include and exclude requested sites | ||
if return_sites is not None: | ||
sample = {k: v for k, v in sample.items() if k in return_sites} | ||
if exclude_vars is not None: | ||
sample = {k: v for k, v in sample.items() if k not in exclude_vars} | ||
else: | ||
guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) | ||
model_trace = poutine.trace(poutine.replay(self.module.model, guide_trace)).get_trace( | ||
|
@@ -235,6 +205,9 @@ def _get_one_posterior_sample( | |
for name, site in model_trace.nodes.items() | ||
if ( | ||
(site["type"] == "sample") # sample statement | ||
and not ( | ||
name in exclude_vars if exclude_vars is not None else False | ||
) # exclude variables | ||
and ( | ||
(return_sites is None) or (name in return_sites) | ||
) # selected in return_sites list | ||
|
@@ -261,6 +234,7 @@ def _get_posterior_samples( | |
num_samples: int = 1000, | ||
return_sites: list | None = None, | ||
return_observed: bool = False, | ||
exclude_vars: list | None = None, | ||
show_progress: bool = True, | ||
): | ||
"""Get many (num_samples=N) samples from posterior distribution. | ||
|
@@ -284,7 +258,11 @@ def _get_posterior_samples( | |
dictionary {variable_name: [array with samples in 0 dimension]} | ||
""" | ||
samples = self._get_one_posterior_sample( | ||
args, kwargs, return_sites=return_sites, return_observed=return_observed | ||
args, | ||
kwargs, | ||
return_sites=return_sites, | ||
return_observed=return_observed, | ||
exclude_vars=exclude_vars, | ||
) | ||
samples = {k: [v] for k, v in samples.items()} | ||
|
||
|
@@ -296,7 +274,11 @@ def _get_posterior_samples( | |
): | ||
# generate new sample | ||
samples_ = self._get_one_posterior_sample( | ||
args, kwargs, return_sites=return_sites, return_observed=return_observed | ||
args, | ||
kwargs, | ||
return_sites=return_sites, | ||
return_observed=return_observed, | ||
exclude_vars=exclude_vars, | ||
) | ||
|
||
# add new sample | ||
|
@@ -365,6 +347,47 @@ def _get_obs_plate_sites( | |
|
||
return obs_plate | ||
|
||
def _get_valid_sites( | ||
self, | ||
args: list, | ||
kwargs: dict, | ||
return_observed: bool = False, | ||
): | ||
"""Automatically guess which model sites should be sampled. | ||
|
||
Parameters | ||
---------- | ||
args | ||
Arguments to the model. | ||
kwargs | ||
Keyword arguments to the model. | ||
return_observed | ||
Record samples of observed variables. | ||
|
||
Returns | ||
------- | ||
List with keys corresponding to site names. | ||
""" | ||
# find plate dimension | ||
trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) | ||
valid_sites = [ | ||
name | ||
for name, site in trace.nodes.items() | ||
if ( | ||
(site["type"] == "sample") # sample statement | ||
and ( | ||
( | ||
(not site.get("is_observed", True)) or return_observed | ||
) # don't save observed unless requested | ||
or (site.get("infer", False).get("_deterministic", False)) | ||
) # unless it is deterministic | ||
and not isinstance( | ||
site.get("fn", None), poutine.subsample_messenger._Subsample | ||
) # don't save plates | ||
) | ||
] | ||
return valid_sites | ||
|
||
@devices_dsp.dedent | ||
def _posterior_samples_minibatch( | ||
self, | ||
|
@@ -415,13 +438,17 @@ def _posterior_samples_minibatch( | |
self.to_device(device) | ||
|
||
if i == 0: | ||
return_observed = getattr(sample_kwargs, "return_observed", False) | ||
# get observation plate sites | ||
return_observed = sample_kwargs.get("return_observed", False) | ||
obs_plate_sites = self._get_obs_plate_sites( | ||
args, kwargs, return_observed=return_observed | ||
) | ||
if len(obs_plate_sites) == 0: | ||
# if no local variables - don't sample | ||
break | ||
# get valid sites & filter local sites | ||
valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) | ||
obs_plate_sites = {k: v for k, v in obs_plate_sites.items() if k in valid_sites} | ||
obs_plate_dim = list(obs_plate_sites.values())[0] | ||
|
||
sample_kwargs_obs_plate = sample_kwargs.copy() | ||
|
@@ -449,10 +476,10 @@ def _posterior_samples_minibatch( | |
i += 1 | ||
|
||
# sample global parameters | ||
valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) | ||
valid_sites = [v for v in valid_sites if v not in obs_plate_sites.keys()] | ||
sample_kwargs["return_sites"] = valid_sites | ||
global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) | ||
global_samples = { | ||
k: v for k, v in global_samples.items() if k not in list(obs_plate_sites.keys()) | ||
} | ||
|
||
for k in global_samples.keys(): | ||
samples[k] = global_samples[k] | ||
|
@@ -471,6 +498,7 @@ def sample_posterior( | |
batch_size: int | None = None, | ||
return_observed: bool = False, | ||
return_samples: bool = False, | ||
exclude_vars: list | None = None, | ||
summary_fun: dict[str, Callable] | None = None, | ||
): | ||
"""Summarise posterior distribution. | ||
|
@@ -531,6 +559,7 @@ def sample_posterior( | |
num_samples=num_samples, | ||
return_sites=return_sites, | ||
return_observed=return_observed, | ||
exclude_vars=exclude_vars, | ||
) | ||
|
||
param_names = list(samples.keys()) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do those two classes exist in the first place?