Add assertion to split function ensuring valid windows #2587

Merged
merged 10 commits into from
Jan 26, 2023

Conversation

marcelkollovieh
Contributor

Issue: #2577

Description of changes:
This PR addresses issue #2577 by adding assertions to OffsetSplitter and DateSplitter that ensure the label targets have a valid length, i.e. that they match prediction_length.
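
As a rough illustration of the kind of check described here (a minimal sketch, not the code from this PR): `label_window` and its arguments are hypothetical names, and the series is a plain Python list rather than a GluonTS data entry. It shows that, without an assertion, slicing past the end of a series silently yields a window shorter than prediction_length.

```python
def label_window(target, split_index, prediction_length):
    """Return the label window that starts at split_index (hypothetical helper)."""
    stop = split_index + prediction_length
    # The window is valid only if it lies entirely inside the series.
    assert 0 <= split_index and stop <= len(target), (
        "Not enough observations after the split point"
    )
    return target[split_index:stop]


series = list(range(10))

# Plain slicing does not complain when the window runs past the end:
print(len(series[8:11]))  # 2, not the requested 3

# With the check, a valid window is returned unchanged...
print(label_window(series, 7, 3))  # [7, 8, 9]

# ...and an invalid one fails loudly instead of returning a short label.
try:
    label_window(series, 8, 3)
except AssertionError as err:
    print("rejected:", err)
```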

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella added the "bug fix (one of pr required labels)" and "pending v0.11.x backport" labels Jan 24, 2023
Contributor

@lostella lostella left a comment

I think we probably miss a few tests here @MarcelK1102 do you want to add them? This module is covered by https://github.com/awslabs/gluonts/blob/dev/test/dataset/test_split.py

Comment on lines 357 to 360
assert (
    base + offset + prediction_length
    <= entry[FieldName.TARGET].shape[-1]
), "Offset too short to generate windows"
Contributor

rhs of the inequality is just label_slice.stop, you could use that I guess. Also maybe in this case the message should be "Date is too early to generate windows"?

Contributor Author

Done

Contributor

My bad: "Date is too late to generate windows" :-)

Contributor Author

Corrected :)
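
To make the wording of the message concrete, here is a minimal sketch under stated assumptions: the series is daily, dates are `datetime.date` objects, and `label_slice_for_date` is a hypothetical helper rather than the DateSplitter implementation. When the split date sits too close to the end of the series, the label slice would extend past the last observation, hence "too late".

```python
from datetime import date


def label_slice_for_date(start, split_date, offset, prediction_length, length):
    """Compute the label slice that follows split_date (hypothetical helper)."""
    # Number of daily observations up to and including split_date.
    base = (split_date - start).days + 1
    label = slice(base + offset, base + offset + prediction_length)
    # The label window must end within the series.
    assert label.stop <= length, "Date is too late to generate windows"
    return label


start = date(2023, 1, 1)
length = 31  # daily observations for 2023-01-01 through 2023-01-31

# Splitting on Jan 24 leaves exactly 7 observations for the label.
print(label_slice_for_date(start, date(2023, 1, 24), 0, 7, length))  # slice(24, 31, None)

# Splitting one day later leaves only 6, so the check fails.
try:
    label_slice_for_date(start, date(2023, 1, 25), 0, 7, length)
except AssertionError as err:
    print("rejected:", err)
```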

Comment on lines -309 to -310
if self.offset < 0 and offset_ >= 0:
    offset_ += len(entry)
Contributor

This was clearly a bug, but I'm wondering whether offset_ += entry[FieldName.TARGET].shape[-1] is still needed here: if self.offset == -offset then offset_ == 0 then the input_slice below may be incorrect? It would be slice(None, 0) which would be empty I think.
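
The concern about `slice(None, 0)` can be checked quickly on a plain Python list (an illustrative stand-in for the target array; the variable names are only for this sketch):

```python
target = list(range(5))

# An end index of 0 selects nothing, regardless of the series length.
input_slice = slice(None, 0)
print(target[input_slice])  # []

# A genuinely negative end index, by contrast, counts from the end.
print(target[slice(None, -2)])  # [0, 1, 2]
```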

Contributor Author

I think it is not needed. If self.offset == -offset and offset_ == 0, then the assertion fails since -offset_ < prediction_length. If we added offset_ += entry[FieldName.TARGET].shape[-1], the assertion would still fail, since -offset_ < prediction_length again. I can't think of any scenario where it would change the outcome.

Contributor

@lostella lostella Jan 24, 2023

Right. I'm wondering though whether ensuring offset_ is always positive would render the logic simpler to read.
So,

if self.offset < 0:
    offset_ += entry[FieldName.TARGET].shape[-1]

assert offset_ + prediction_length <= entry[FieldName.TARGET].shape[-1], "Offset too short"

(didn't check this ☝️ thoroughly but makes sense to me)

Contributor Author

The two inequalities are equivalent then. I tested it and it works. I updated it accordingly.
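
For a negative offset, the equivalence follows from offset + n + prediction_length <= n being the same as -offset >= prediction_length, where n is the series length. A small brute-force sketch of that claim (the function names are hypothetical and the ranges arbitrary; this is not the test added in the PR):

```python
def check_negative_form(offset, prediction_length):
    # Check on the raw negative offset: enough room is left at the end.
    return -offset >= prediction_length


def check_normalized_form(offset, prediction_length, length):
    # Shift a negative offset to a non-negative index first, then check.
    offset_ = offset + length if offset < 0 else offset
    return offset_ + prediction_length <= length


def forms_agree(max_length=20, max_prediction_length=25):
    """Compare both checks over all negative offsets for small series."""
    for length in range(1, max_length):
        for offset in range(-length, 0):
            for prediction_length in range(1, max_prediction_length):
                a = check_negative_form(offset, prediction_length)
                b = check_normalized_form(offset, prediction_length, length)
                if a != b:
                    return False
    return True


print(forms_agree())  # True
```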

@marcelkollovieh
Contributor Author

> I think we probably miss a few tests here @MarcelK1102 do you want to add them? This module is covered by https://github.com/awslabs/gluonts/blob/dev/test/dataset/test_split.py

@lostella Sure, I will do that!

@lostella lostella added this to the v0.12 milestone Jan 24, 2023
@lostella
Contributor

Looks good now! Let's add the tests and then this can be merged!

@lostella added the "pending v0.12.x backport" label Jan 26, 2023
@lostella enabled auto-merge (squash) January 26, 2023 17:08
@lostella merged commit 85b3dde into awslabs:dev Jan 26, 2023
lostella pushed a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
lostella pushed a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
lostella added a commit that referenced this pull request Jan 30, 2023
* Fix: avoid automatic device detection via serialized tensors when deserializing. (#2576)

* Make itertools Map/Filter dataclasses. (#2579)

* serde: Fix encoding of dtypes. (#2586)

* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* fix requirements further

* fix style

* remove undesired change

---------

Co-authored-by: Shubham Kapoor <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
lostella added a commit that referenced this pull request Feb 2, 2023
* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Move NPTS back to `gluonts.model` (#2597)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix bug with static cardinalities in `PandasDataset` (#2599)

* Expose `weight_decay` in torch TFT estimator class (#2603)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* Fix incorrect import in `tsbench`, apply latest black (#2613)

* Allow ReduceLROnPlateau to track val_loss when validation set is available (#2614)

---------

Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: Gerald Woo <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
@lostella removed the "pending v0.11.x backport" and "pending v0.12.x backport" labels Feb 7, 2023