epinowcast
Loading...
Searching...
No Matches
delay_multinomial_lpmf.stan
Go to the documentation of this file.
1
/**
 * Multinomial log probability mass of a vector of observed counts given
 * log-scale (possibly unnormalised) expected values for each category.
 *
 * multinomial_logit_lpmf applies a softmax to its second argument, so the
 * values are passed through directly rather than being normalised with
 * log_softmax first; softmax is invariant to additive shifts, so the result
 * is the same as normalising beforehand.
 *
 * @param obs Observed counts, one per category.
 * @param log_exp_obs Log-scale expected values, one per category (must have
 *   the same length as obs).
 * @return The multinomial log probability mass of obs.
 */
real delay_multinomial_lpmf(array[] int obs, vector log_exp_obs) {
  return multinomial_logit_lpmf(obs | log_exp_obs);
}
17
/**
 * Summed multinomial log-likelihood over snapshots start..end, where each
 * snapshot has a block of expected cells up to a cutoff, of which some or
 * all are observed.
 *
 * For each snapshot i with lsl[i] > 0 expected cells, the log-scale
 * expected values for that block are renormalised with log_softmax. Then:
 *  - if every cell in the block is observed (nsl[i] == lsl[i]), the observed
 *    counts are scored with a plain multinomial;
 *  - otherwise the observed cells are scored alongside one extra residual
 *    category. The residual collects all unobserved cells: its count is
 *    total[i] minus the observed counts (floored at 0), and its probability
 *    is the summed probability of the unobserved cells.
 * Snapshots with lsl[i] == 0 contribute nothing.
 *
 * @param start First snapshot index to include.
 * @param end Last snapshot index to include.
 * @param obs Observed counts for all snapshots, concatenated.
 * @param log_exp_obs Log-scale expected values, concatenated; indexed at
 *   (block start - exp_offset + 1).
 * @param exp_offset Offset subtracted from expected-cell start positions
 *   when indexing into log_exp_obs.
 * @param total Per-snapshot total count used for the residual category.
 * @param obs_lookup For each entry of obs, its position in the same index
 *   space as the expected-cell block starts (used to find the matching
 *   expected cell).
 * @param lsl Per-snapshot number of expected cells up to the cutoff.
 * @param clsl Cumulative counterpart of lsl passed to filt_obs_indexes.
 * @param nsl Per-snapshot number of observed cells.
 * @param cnsl Cumulative counterpart of nsl passed to filt_obs_indexes.
 * @return The summed log-likelihood contribution.
 *
 * NOTE(review): only element [1] of filt_obs_indexes' result is used, and
 * it is treated as the start position of snapshot i's block; confirm
 * against filt_obs_indexes.
 */
real delay_multinomial_snaps(int start, int end, array[] int obs,
                             vector log_exp_obs, int exp_offset,
                             array[] int total, array[] int obs_lookup,
                             array[] int lsl, array[] int clsl,
                             array[] int nsl, array[] int cnsl) {
  real tar = 0;
  for (i in start:end) {
    if (lsl[i]) {
      // Cutoff block of expected cells; renormalise over all slots up to the
      // cutoff (log_softmax) so before-cutoff interior cells keep weight.
      array[3] int c = filt_obs_indexes(i, i, clsl, lsl);
      vector[lsl[i]] lprob = log_softmax(
        segment(log_exp_obs, c[1] - exp_offset + 1, lsl[i])
      );
      array[3] int o = filt_obs_indexes(i, i, cnsl, nsl);
      if (nsl[i] == lsl[i]) {
        // Fully observed up to the cutoff: plain (truncated) multinomial.
        tar += multinomial_lpmf(segment(obs, o[1], nsl[i]) | exp(lprob));
      } else {
        // Some before-cutoff cells unobserved: marginalise them into one
        // residual category (count = total - observed).
        vector[nsl[i] + 1] cat_lprob;
        array[nsl[i] + 1] int cat_obs;
        // Copy of the block's log-probabilities in which every observed cell
        // is masked to -inf; what remains is exactly the unobserved mass.
        vector[lsl[i]] unobs_lprob = lprob;
        int obs_sum = 0;
        for (j in 1:nsl[i]) {
          int local = obs_lookup[o[1] + j - 1] - c[1] + 1;
          cat_lprob[j] = lprob[local];
          unobs_lprob[local] = negative_infinity();
          cat_obs[j] = obs[o[1] + j - 1];
          obs_sum += cat_obs[j];
        }
        // Residual probability as the sum of the unobserved cells'
        // probabilities. Equal to 1 - sum(observed) because lprob is
        // normalised, but it avoids the cancellation (and NaN when roundoff
        // makes the observed log-sum >= 0) of log1m_exp(log_sum_exp(...)).
        cat_lprob[nsl[i] + 1] = log_sum_exp(unobs_lprob);
        cat_obs[nsl[i] + 1] = max(total[i] - obs_sum, 0);
        tar += multinomial_lpmf(cat_obs | exp(cat_lprob));
      }
    }
  }
  return tar;
}
real delay_multinomial_snaps(int start, int end, array[] int obs, vector log_exp_obs, int exp_offset, array[] int total, array[] int obs_lookup, array[] int lsl, array[] int clsl, array[] int nsl, array[] int cnsl)
real delay_multinomial_lpmf(array[] int obs, vector log_exp_obs)
array[] int filt_obs_indexes(int start, int end, array[] int csl, array[] int sl)