Structured covariates don't allow for post-processing when loop=FALSE #1540

Closed
wds15 opened this issue Aug 25, 2023 · 3 comments

wds15 commented Aug 25, 2023

When structured covariates are used in non-linear models that have loop=FALSE set, post-processing is broken. Here is a small example:

library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.20.1). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
set.seed(2134)
N <- 100
dat <- data.frame(y=rnorm(N))
dat$X <- matrix(rnorm(N*2), N, 2)

nlfun_stan <- "
  vector nlfun(vector a, vector b, vector c, matrix X) {
     vector[rows(a)] res;
     for(i in 1:rows(a)) {
         res[i] = a[i] + b[i] * X[i,1] + c[i] * X[i,2];
      }
     return res;
  }
"
nlstanvar <- stanvar(scode = nlfun_stan, block = "functions")

# version for R post processing
nlfun <- function(a, b, c, X) {
  a + b * X[, , 1] + c * X[, , 2]
}

# fit the model
bform <- bf(y~nlfun(a, b, c, X), a~1, b~1, c~1, nl = TRUE, loop=FALSE)
fit <- brm(bform, dat, stanvars = nlstanvar, refresh=0)
#> Compiling Stan program...
#> Start sampling
summary(fit)
#>  Family: gaussian 
#>   Links: mu = identity; sigma = identity 
#> Formula: y ~ nlfun(a, b, c, X) 
#>          a ~ 1
#>          b ~ 1
#>          c ~ 1
#>    Data: dat (Number of observations: 100) 
#>   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 4000
#> 
#> Population-Level Effects: 
#>             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> a_Intercept    -0.02      0.10    -0.22     0.19 1.00     4600     2598
#> b_Intercept    -0.08      0.10    -0.28     0.11 1.00     4108     2977
#> c_Intercept     0.03      0.10    -0.17     0.24 1.00     4203     3152
#> 
#> Family Specific Parameters: 
#>       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sigma     1.02      0.07     0.89     1.18 1.00     3853     3113
#> 
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).

# fit benchmark model that should yield the same results up to MCMC error
fit2 <- brm(y~X, dat, refresh=0)
#> Compiling Stan program...
#> Start sampling
summary(fit2)
#>  Family: gaussian 
#>   Links: mu = identity; sigma = identity 
#> Formula: y ~ X 
#>    Data: dat (Number of observations: 100) 
#>   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 4000
#> 
#> Population-Level Effects: 
#>           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept    -0.02      0.10    -0.21     0.19 1.00     4598     3136
#> X1           -0.09      0.10    -0.27     0.10 1.00     4389     2822
#> X2            0.03      0.10    -0.18     0.23 1.00     4163     2645
#> 
#> Family Specific Parameters: 
#>       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sigma     1.02      0.08     0.89     1.18 1.00     4213     2387
#> 
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).

# post processing should work too... but it breaks
str(posterior_epred(fit))
#> Error in row(args[[i]]): a matrix-like object is required as argument to 'row'
loo(fit, fit2)
#> Error in row(args[[i]]): a matrix-like object is required as argument to 'row'

Created on 2023-08-25 with reprex v2.0.2

@paul-buerkner paul-buerkner added this to the 2.21.0 milestone Aug 28, 2023

wds15 commented Aug 28, 2023

It seems the user can work around the issue with a small hack: setting the loop attribute of the formula to TRUE puts post-processing back into a working state. So doing this

# setting the loop attribute to TRUE makes things work
attr(fit$formula$formula, "loop") <- TRUE
attr(fit2$formula$formula, "loop") <- TRUE

# now ok:
str(posterior_epred(fit))
loo(fit, fit2)

at the end of the above script makes things work, as far as I can tell. It would be great if @paul-buerkner could confirm that this does the right thing: all internal brms checks now seem to pass, and the arguments appear to be passed to the user-defined function as the simulation needs them.
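For orientation, the R version of nlfun in the original report indexes X with three subscripts, which suggests that with the loop attribute set to TRUE the covariate reaches the function as a draws x observations x columns array. Below is a minimal sketch in plain R (no brms needed) under that assumption. The shapes and the names n_draws, n_obs, par_a, par_b, par_c and nlfun_all_draws are illustrative only and are not taken from brms internals.

```r
# Assumed shapes: parameters are draws x observations matrices,
# X is a draws x observations x columns array (an assumption, see above).
set.seed(1)
n_draws <- 4
n_obs <- 3

par_a <- matrix(rnorm(n_draws * n_obs), n_draws, n_obs)
par_b <- matrix(rnorm(n_draws * n_obs), n_draws, n_obs)
par_c <- matrix(rnorm(n_draws * n_obs), n_draws, n_obs)
X <- array(rnorm(n_draws * n_obs * 2), dim = c(n_draws, n_obs, 2))

# Same body as the nlfun from the report: the third index selects the column.
nlfun_all_draws <- function(a, b, c, X) {
  a + b * X[, , 1] + c * X[, , 2]
}

mu <- nlfun_all_draws(par_a, par_b, par_c, X)
dim(mu)  # draws x observations
```

With inputs laid out like this, the three-subscript indexing yields one predicted value per draw and observation.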


wds15 commented Aug 28, 2023

Note that I really do not want to use loop=TRUE during fitting, as I need more control over what happens (in fact, I am hijacking the brms-generated code so that I loop over patients in a reduce_sum approach, which is quite useful). I am fine with having loop=TRUE during simulation, though.

paul-buerkner added a commit that referenced this issue Sep 14, 2023
paul-buerkner commented

This should now be fixed on GitHub. Note that, for use with loop = FALSE, the post-processing will run the non-linear function separately for each posterior draw, so you have to use the following R definition, which indexes X as a two-dimensional matrix, to avoid further errors:

nlfun <- function(a, b, c, X) {
  a + b * X[, 1] + c * X[, 2]
}
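The difference between the two indexing styles can be checked in plain R without fitting a model. The sketch below assumes that, for a single posterior draw, each parameter arrives as a vector with one element per observation and X as the original observations x columns matrix. The names n_obs, par_a, par_b, par_c, nlfun_per_draw and nlfun_3d are illustrative only.

```r
# Assumed per-draw shapes: parameter vectors of length n_obs, X is n_obs x 2.
set.seed(1)
n_obs <- 5
X <- matrix(rnorm(n_obs * 2), n_obs, 2)
par_a <- rnorm(n_obs)
par_b <- rnorm(n_obs)
par_c <- rnorm(n_obs)

# Two-subscript version, as in the comment above.
nlfun_per_draw <- function(a, b, c, X) {
  a + b * X[, 1] + c * X[, 2]
}

# Three-subscript version from the original report.
nlfun_3d <- function(a, b, c, X) {
  a + b * X[, , 1] + c * X[, , 2]
}

mu <- nlfun_per_draw(par_a, par_b, par_c, X)
length(mu)  # one value per observation

# A plain matrix has only two dimensions, so three subscripts raise an error;
# the message is captured here instead of stopping the script.
err <- tryCatch(
  nlfun_3d(par_a, par_b, par_c, X),
  error = function(e) conditionMessage(e)
)
print(err)
```

Under these assumptions only the two-subscript version returns a result, which is consistent with the advice to change the function when post-processing runs draw by draw.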
