Infer dims & coords from xarray variables passed to pm.Data #5791

Open
Tracked by #7053
michaelosthege opened this issue May 21, 2022 · 7 comments

Comments

@michaelosthege
Member

Description

The internal function pm.data.determine_coords is called when pm.Data(..., export_index_as_coords=True) is invoked.

Right now it can only infer coords from a pd.Series or pd.DataFrame, but it would be nice (and actually quite simple) to do the same for xarray.DataArray variables.

After all, the pm.Data variables become xarray.DataArray again, when they're stored in the InferenceData.
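
As a rough illustration of why this should be simple: a DataArray already carries its dimension names and index coordinates. The sketch below reads them off directly. The helper name `infer_from_dataarray` is hypothetical and is not PyMC's `determine_coords`; it only assumes the public `DataArray.dims` and `DataArray.coords` attributes.

```python
import numpy as np
import xarray as xr


def infer_from_dataarray(value):
    """Return (coords, dims) read off an xarray.DataArray.

    Hypothetical helper for illustration only; not PyMC's implementation.
    """
    # DataArray.dims is a tuple of dimension names in axis order.
    dims = list(value.dims)
    # Only dimensions that actually carry coordinate values contribute coords.
    coords = {
        str(dim): tuple(value.coords[dim].values.tolist())
        for dim in value.dims
        if dim in value.coords
    }
    return coords, dims


da = xr.DataArray(
    np.zeros((2, 3)),
    dims=("city", "year"),
    coords={"city": ["Berlin", "Paris"], "year": [2020, 2021, 2022]},
)
coords, dims = infer_from_dataarray(da)
print(dims)
print(coords)
```

Dimensions without coordinate values are still reported in `dims` but are left out of `coords`, so the caller can decide how to label them.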

Deliverables

  • Inferring dims from named pandas indexes
  • Inferring dims from DataArrays
  • Inferring coords from DataArrays
  • A unit test
  • Possibly a backwards-compatible rename (with a deprecation warning) of export_index_as_coords to infer_dims_and_coords.
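
One way the backwards-compatible rename from the last bullet could look is sketched here. `resolve_infer_flag` is a hypothetical helper, not part of PyMC; it only shows the common pattern of accepting the old keyword, warning, and forwarding its value to the new one.

```python
import warnings


def resolve_infer_flag(infer_dims_and_coords=False, export_index_as_coords=None):
    """Map the old keyword onto the new one (hypothetical helper)."""
    if export_index_as_coords is not None:
        # The old keyword was passed explicitly: warn, then honor its value.
        warnings.warn(
            "`export_index_as_coords` is deprecated; "
            "use `infer_dims_and_coords` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return bool(export_index_as_coords)
    return bool(infer_dims_and_coords)
```

Using `None` as the default for the old keyword lets the helper tell "not passed" apart from an explicit `False`, so the warning fires only for callers that still use the old name.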
@michaelosthege michaelosthege changed the title Enhance determine_coords to infer dims & coords from xarray variables passed to pm.Data Infer dims & coords from xarray variables passed to pm.Data May 21, 2022
@mjhajharia
Member

mjhajharia commented May 23, 2022

dim_name = value.index.name already gets the dimension name from the index name, so would it suffice to assign that to dims before returning? The name is already present in the coords that are being returned. (I might be missing something; I've been away for a while.)

edit: this comment is only for the df part of the function

@michaelosthege
Member Author

I suppose you're referring to these lines, which I recently edited in #5763 exactly to prepare for this issue.

pymc/pymc/data.py

Lines 511 to 515 in b5a5b56

if dims is None:
    # TODO: Also determine dim names from the index
    dims = [None] * np.ndim(value)
return coords, dims

But yes, you captured the gist of it.

We'll just have to remember to not override non-None elements in the user-provided dims.

The xarray support will be just another if block with hasattr or isinstance checks.
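
The "don't override non-None elements" rule can be sketched as a small merge step. `merge_dims` is a hypothetical name used only for illustration; the actual change in data.py may be structured differently.

```python
def merge_dims(user_dims, inferred_dims):
    """Fill gaps in user-provided dims from inferred ones (hypothetical helper)."""
    if user_dims is None:
        # Nothing was provided by the user, so take the inferred names as they are.
        return list(inferred_dims)
    if len(user_dims) != len(inferred_dims):
        raise ValueError(
            f"Got {len(user_dims)} user dims but {len(inferred_dims)} inferred dims."
        )
    # A user-provided name always wins; only None entries are filled in.
    return [
        user if user is not None else inferred
        for user, inferred in zip(user_dims, inferred_dims)
    ]


print(merge_dims(["obs", None], ["row", "column"]))  # ['obs', 'column']
```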

@michaelraczycki
Contributor

is this still open? I'd like to take a look at this

@michaelosthege
Member Author

is this still open? I'd like to take a look at this

Yes, feel free to take over

@michaelraczycki
Contributor

@michaelosthege, could you briefly explain the naming convention here? In data.py we seem to have a bunch of classes, conventionally named in CamelCase, but then we also have standalone functions, named with nouns, that initialize the "Data" objects. I know it's not meant to be OO, but I can't work out which style it follows.

@michaelosthege
Member Author

michaelosthege commented Feb 10, 2023

Yes, so the pm.Data container was introduced with CamelCase naming to make it look like other PyMC model variables/distributions like pm.Normal.
pm.ConstantData and pm.MutableData reproduced that.
This might also be a relic of v3 designs that actually had a "container" object.

Technically, since v4 the pm.Normal("n") is also not a pm.Normal object:

>>> import pymc as pm
>>> with pm.Model():
...     n = pm.Normal("n")
...
>>> isinstance(n, pm.Normal)
False
>>> type(n).__mro__
(<class 'pytensor.tensor.var.TensorVariable'>, <class 'pytensor.tensor.var._tensor_py_operators'>,
<class 'pytensor.graph.basic.Variable'>, <class 'pytensor.graph.basic.Node'>,
<class 'pytensor.graph.utils.MetaObject'>, <class 'typing.Generic'>,
<class 'object'>)
>>>

This is because of some pm.Distribution.__new__ magic that I don't understand.
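
For readers unfamiliar with the mechanism: in plain Python, `__new__` may return an object of any type, and when that object is not an instance of the class, `__init__` is skipped. The toy classes below are hypothetical stand-ins and do not reproduce PyMC's actual `pm.Distribution.__new__`; they only show how an `isinstance` check like the one above can come out `False`.

```python
class Variable:
    """Toy stand-in for a graph variable (hypothetical)."""

    def __init__(self, name):
        self.name = name


class Normal:
    """Toy stand-in for a distribution class (hypothetical)."""

    def __new__(cls, name):
        # Return an unrelated object instead of an instance of cls.
        # Because the result is not a Normal, Python does not call Normal.__init__.
        return Variable(name)


n = Normal("n")
print(isinstance(n, Normal))  # False
print(type(n).__name__)  # Variable
```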

There's a still-open discussion about refactoring this to actual functions (like pm.Data), but the naming of the user-facing API will probably remain CamelCase for the foreseeable future.

@michaelraczycki
Contributor

Thanks! Also, since data.py has changed quite a lot since the last updates on this issue, I'll just rewrite the whole functionality. From what I can see, the file is about 200 lines shorter than it was back then, and I guess that's easier than adapting the legacy version of the file.

twiecki pushed a commit that referenced this issue Feb 22, 2023
* added dim inference from xarray, deprecation warning and unittest for the new feature

* fixed typo in warning

* fixed accidental quotation around dim

* fixed failing assertions

* found and fixed cause of the failing test

* changed the coords assertion according to suggested form

* fixing mypy type missmatch

* working on getting the test to work

* removed typecasting to string on dim_name, was causing the mypy to fail

* took care locally of mypy errors

* Typo/formatting fixes

---------

Co-authored-by: Michal Raczycki
Co-authored-by: Michael Osthege