
[Bug]: initialize_q_batch does not always include the maximum value when called in batch mode #2772

Closed
JackBuck opened this issue Mar 17, 2025 · 4 comments
Labels
bug Something isn't working

Comments

JackBuck (Contributor) commented Mar 17, 2025

What happened?

When calling initialize_q_batch() with a non-trivial batch_shape, the maximum of the input acq_values for each batch is not guaranteed to be included in the selected result. In fact, as batch_shape grows, the probability of including it shrinks to zero.

The code which is supposed to include the maximum is here:

# make sure we get the maximum
if max_idx not in idcs:
    idcs[-1] = max_idx

The variable max_idx has shape batch_shape and idcs has shape n x batch_shape. The check max_idx not in idcs only verifies that the maximum appears in idcs for at least one batch, so it passes even when other batches are missing theirs. I would expect initialize_q_batch() to ensure the maximum of the input acq_values is included in every batch.
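
A per-batch correction needs this membership test broadcast over the batch dimensions. Below is a minimal sketch of that idea, not the actual fix merged in #2773; ensure_max_included is a hypothetical helper, and the shapes are those described above (idcs: n x batch_shape, max_idx: batch_shape):

import torch

def ensure_max_included(idcs: torch.Tensor, max_idx: torch.Tensor) -> torch.Tensor:
    # Per batch element, check whether max_idx occurs anywhere along dim 0.
    contains_max = (idcs == max_idx.unsqueeze(0)).any(dim=0)  # batch_shape
    # Where the maximum is missing, overwrite the last selected index with it.
    idcs[-1] = torch.where(contains_max, idcs[-1], max_idx)
    return idcs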

Please provide a minimal, reproducible example of the unexpected behavior.

import torch
from botorch.optim import initialize_q_batch

if __name__ == "__main__":
    torch.manual_seed(1234)

    X = torch.rand((20, 100, 1, 2))  # b x batch_shape x q x d
    Y = torch.sum(X**2, dim=[-2, -1])  # b x batch_shape

    true_max, true_max_idx = Y.max(dim=0)  # (batch_shape, batch_shape)

    X_init, acq_init = initialize_q_batch(X, acq_vals=Y, n=1)

    acq_init_max, _ = acq_init.max(dim=0)  # batch_shape
    mask = acq_init_max != true_max

    print(f"{mask.sum()} discrepancies")
    if mask.any():
        idx = torch.arange(X.shape[1])[mask][0]
        print(f"E.g. Batch index {idx}:")
        print(f"  Max input: {true_max[idx]} (index {true_max_idx[idx]})")
        print(f"  Max selected: {acq_init_max[idx]}")

Please paste any relevant traceback/logs produced by the example provided.

76 discrepancies
E.g. Batch index 2:
  Max input: 1.706216812133789 (index 1)
  Max selected: 1.3131837844848633

BoTorch Version

0.13.0

Python Version

3.13.2

Operating System

Ubuntu 20.04.6 LTS (Focal Fossa)

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct
JackBuck (Contributor, Author)

Potential fix at #2773

Balandat (Contributor)

Thanks, this is a nice find of a not-so-nice issue.

facebook-github-bot pushed a commit that referenced this issue Mar 17, 2025
…led in batch mode (#2773)

Summary:

## Motivation
Fix for #2772

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2773

Test Plan:
- Added an assert to the existing test to ensure the max value is included for every element of the batch.
- Added a test case with a large batch size.
- Verified the test failed before making the change, and that it passes after.
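
A hypothetical sketch of the kind of assertion described above, illustrative only and not the actual BoTorch test code:

import torch
from botorch.optim import initialize_q_batch

torch.manual_seed(0)
X = torch.rand((20, 512, 1, 2))  # b x large batch_shape x q x d
Y = torch.sum(X**2, dim=[-2, -1])  # b x batch_shape
_, acq_init = initialize_q_batch(X, acq_vals=Y, n=1)
# Every batch element must retain its true maximum after selection.
assert torch.equal(acq_init.max(dim=0).values, Y.max(dim=0).values)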

## Related PRs

No related PRs; no changes to docs required.

Reviewed By: Balandat

Differential Revision: D71320417

Pulled By: saitcakmak

fbshipit-source-id: d72a50cb5a6c9c3ecb672b9d78ca2e373fd87e04
Balandat (Contributor)

@JackBuck can we consider this issue resolved?

JackBuck (Contributor, Author)

Hi @Balandat, sorry, yes - and thanks for the swift merging of the PR!
