khmm.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. #include <math.h>
  2. #include <stdio.h>
  3. #include <assert.h>
  4. #include <string.h>
  5. #include <stdlib.h>
  6. #include "khmm.h"
  7. // new/delete hmm_par_t
  8. hmm_par_t *hmm_new_par(int m, int n)
  9. {
  10. hmm_par_t *hp;
  11. int i;
  12. assert(m > 0 && n > 0);
  13. hp = (hmm_par_t*)calloc(1, sizeof(hmm_par_t));
  14. hp->m = m; hp->n = n;
  15. hp->a0 = (FLOAT*)calloc(n, sizeof(FLOAT));
  16. hp->a = (FLOAT**)calloc2(n, n, sizeof(FLOAT));
  17. hp->e = (FLOAT**)calloc2(m + 1, n, sizeof(FLOAT));
  18. hp->ae = (FLOAT**)calloc2((m + 1) * n, n, sizeof(FLOAT));
  19. for (i = 0; i != n; ++i) hp->e[m][i] = 1.0;
  20. return hp;
  21. }
  22. void hmm_delete_par(hmm_par_t *hp)
  23. {
  24. int i;
  25. if (hp == 0) return;
  26. for (i = 0; i != hp->n; ++i) free(hp->a[i]);
  27. for (i = 0; i <= hp->m; ++i) free(hp->e[i]);
  28. for (i = 0; i < (hp->m + 1) * hp->n; ++i) free(hp->ae[i]);
  29. free(hp->a); free(hp->e); free(hp->a0); free(hp->ae);
  30. free(hp);
  31. }
  32. // new/delete hmm_data_t
  33. hmm_data_t *hmm_new_data(int L, const char *seq, const hmm_par_t *hp)
  34. {
  35. hmm_data_t *hd;
  36. hd = (hmm_data_t*)calloc(1, sizeof(hmm_data_t));
  37. hd->L = L;
  38. hd->seq = (char*)malloc(L + 1);
  39. memcpy(hd->seq + 1, seq, L);
  40. return hd;
  41. }
  42. void hmm_delete_data(hmm_data_t *hd)
  43. {
  44. int i;
  45. if (hd == 0) return;
  46. for (i = 0; i <= hd->L; ++i) {
  47. if (hd->f) free(hd->f[i]);
  48. if (hd->b) free(hd->b[i]);
  49. }
  50. free(hd->f); free(hd->b); free(hd->s); free(hd->v); free(hd->p); free(hd->seq);
  51. free(hd);
  52. }
  53. // new/delete hmm_exp_t
  54. hmm_exp_t *hmm_new_exp(const hmm_par_t *hp)
  55. {
  56. hmm_exp_t *he;
  57. assert(hp);
  58. he = (hmm_exp_t*)calloc(1, sizeof(hmm_exp_t));
  59. he->m = hp->m; he->n = hp->n;
  60. he->A0 = (FLOAT*)calloc(hp->n, sizeof(FLOAT));
  61. he->A = (FLOAT**)calloc2(hp->n, hp->n, sizeof(FLOAT));
  62. he->E = (FLOAT**)calloc2(hp->m + 1, hp->n, sizeof(FLOAT));
  63. return he;
  64. }
  65. void hmm_delete_exp(hmm_exp_t *he)
  66. {
  67. int i;
  68. if (he == 0) return;
  69. for (i = 0; i != he->n; ++i) free(he->A[i]);
  70. for (i = 0; i <= he->m; ++i) free(he->E[i]);
  71. free(he->A); free(he->E); free(he->A0);
  72. free(he);
  73. }
  74. // Viterbi algorithm
  75. FLOAT hmm_Viterbi(const hmm_par_t *hp, hmm_data_t *hd)
  76. {
  77. FLOAT **la, **le, *preV, *curV, max;
  78. int **Vmax, max_l; // backtrace matrix
  79. int k, l, b, u;
  80. if (hd->v) free(hd->v);
  81. hd->v = (int*)calloc(hd->L+1, sizeof(int));
  82. la = (FLOAT**)calloc2(hp->n, hp->n, sizeof(FLOAT));
  83. le = (FLOAT**)calloc2(hp->m + 1, hp->n, sizeof(FLOAT));
  84. Vmax = (int**)calloc2(hd->L+1, hp->n, sizeof(int));
  85. preV = (FLOAT*)malloc(sizeof(FLOAT) * hp->n);
  86. curV = (FLOAT*)malloc(sizeof(FLOAT) * hp->n);
  87. for (k = 0; k != hp->n; ++k)
  88. for (l = 0; l != hp->n; ++l)
  89. la[k][l] = log(hp->a[l][k]); // this is not a bug
  90. for (b = 0; b != hp->m; ++b)
  91. for (k = 0; k != hp->n; ++k)
  92. le[b][k] = log(hp->e[b][k]);
  93. for (k = 0; k != hp->n; ++k) le[hp->m][k] = 0.0;
  94. // V_k(1)
  95. for (k = 0; k != hp->n; ++k) {
  96. preV[k] = le[(int)hd->seq[1]][k] + log(hp->a0[k]);
  97. Vmax[1][k] = 0;
  98. }
  99. // all the rest
  100. for (u = 2; u <= hd->L; ++u) {
  101. FLOAT *tmp, *leu = le[(int)hd->seq[u]];
  102. for (k = 0; k != hp->n; ++k) {
  103. FLOAT *laa = la[k];
  104. for (l = 0, max = -HMM_INF, max_l = -1; l != hp->n; ++l) {
  105. if (max < preV[l] + laa[l]) {
  106. max = preV[l] + laa[l];
  107. max_l = l;
  108. }
  109. }
  110. assert(max_l >= 0); // cannot be zero
  111. curV[k] = leu[k] + max;
  112. Vmax[u][k] = max_l;
  113. }
  114. tmp = curV; curV = preV; preV = tmp; // swap
  115. }
  116. // backtrace
  117. for (k = 0, max_l = -1, max = -HMM_INF; k != hp->n; ++k) {
  118. if (max < preV[k]) {
  119. max = preV[k]; max_l = k;
  120. }
  121. }
  122. assert(max_l >= 0); // cannot be zero
  123. hd->v[hd->L] = max_l;
  124. for (u = hd->L; u >= 1; --u)
  125. hd->v[u-1] = Vmax[u][hd->v[u]];
  126. for (k = 0; k != hp->n; ++k) free(la[k]);
  127. for (b = 0; b < hp->m; ++b) free(le[b]);
  128. for (u = 0; u <= hd->L; ++u) free(Vmax[u]);
  129. free(la); free(le); free(Vmax); free(preV); free(curV);
  130. hd->status |= HMM_VITERBI;
  131. return max;
  132. }
  133. // forward algorithm
  134. void hmm_forward(const hmm_par_t *hp, hmm_data_t *hd)
  135. {
  136. FLOAT sum, tmp, **at;
  137. int u, k, l;
  138. int n, m, L;
  139. assert(hp && hd);
  140. // allocate memory for hd->f and hd->s
  141. n = hp->n; m = hp->m; L = hd->L;
  142. if (hd->s) free(hd->s);
  143. if (hd->f) {
  144. for (k = 0; k <= hd->L; ++k) free(hd->f[k]);
  145. free(hd->f);
  146. }
  147. hd->f = (FLOAT**)calloc2(hd->L+1, hp->n, sizeof(FLOAT));
  148. hd->s = (FLOAT*)calloc(hd->L+1, sizeof(FLOAT));
  149. hd->status &= ~(unsigned)HMM_FORWARD;
  150. // at[][] array helps to improve the cache efficiency
  151. at = (FLOAT**)calloc2(n, n, sizeof(FLOAT));
  152. // transpose a[][]
  153. for (k = 0; k != n; ++k)
  154. for (l = 0; l != n; ++l)
  155. at[k][l] = hp->a[l][k];
  156. // f[0], but it should never be used
  157. hd->s[0] = 1.0;
  158. for (k = 0; k != n; ++k) hd->f[0][k] = 0.0;
  159. // f[1]
  160. for (k = 0, sum = 0.0; k != n; ++k)
  161. sum += (hd->f[1][k] = hp->a0[k] * hp->e[(int)hd->seq[1]][k]);
  162. for (k = 0; k != n; ++k) hd->f[1][k] /= sum;
  163. hd->s[1] = sum;
  164. // f[2..hmmL], the core loop
  165. for (u = 2; u <= L; ++u) {
  166. FLOAT *fu = hd->f[u], *fu1 = hd->f[u-1], *eu = hp->e[(int)hd->seq[u]];
  167. for (k = 0, sum = 0.0; k != n; ++k) {
  168. FLOAT *aa = at[k];
  169. for (l = 0, tmp = 0.0; l != n; ++l) tmp += fu1[l] * aa[l];
  170. sum += (fu[k] = eu[k] * tmp);
  171. }
  172. for (k = 0; k != n; ++k) fu[k] /= sum;
  173. hd->s[u] = sum;
  174. }
  175. // free at array
  176. for (k = 0; k != hp->n; ++k) free(at[k]);
  177. free(at);
  178. hd->status |= HMM_FORWARD;
  179. }
  180. // precalculate hp->ae
  181. void hmm_pre_backward(hmm_par_t *hp)
  182. {
  183. int m, n, b, k, l;
  184. assert(hp);
  185. m = hp->m; n = hp->n;
  186. for (b = 0; b <= m; ++b) {
  187. for (k = 0; k != n; ++k) {
  188. FLOAT *p = hp->ae[b * hp->n + k];
  189. for (l = 0; l != n; ++l)
  190. p[l] = hp->e[b][l] * hp->a[k][l];
  191. }
  192. }
  193. }
  194. // backward algorithm
  195. void hmm_backward(const hmm_par_t *hp, hmm_data_t *hd)
  196. {
  197. FLOAT tmp;
  198. int k, l, u;
  199. int m, n, L;
  200. assert(hp && hd);
  201. assert(hd->status & HMM_FORWARD);
  202. // allocate memory for hd->b
  203. m = hp->m; n = hp->n; L = hd->L;
  204. if (hd->b) {
  205. for (k = 0; k <= hd->L; ++k) free(hd->b[k]);
  206. free(hd->b);
  207. }
  208. hd->status &= ~(unsigned)HMM_BACKWARD;
  209. hd->b = (FLOAT**)calloc2(L+1, hp->n, sizeof(FLOAT));
  210. // b[L]
  211. for (k = 0; k != hp->n; ++k) hd->b[L][k] = 1.0 / hd->s[L];
  212. // b[1..L-1], the core loop
  213. for (u = L-1; u >= 1; --u) {
  214. FLOAT *bu1 = hd->b[u+1], **p = hp->ae + (int)hd->seq[u+1] * n;
  215. for (k = 0; k != n; ++k) {
  216. FLOAT *q = p[k];
  217. for (l = 0, tmp = 0.0; l != n; ++l) tmp += q[l] * bu1[l];
  218. hd->b[u][k] = tmp / hd->s[u];
  219. }
  220. }
  221. hd->status |= HMM_BACKWARD;
  222. for (l = 0, tmp = 0.0; l != n; ++l)
  223. tmp += hp->a0[l] * hd->b[1][l] * hp->e[(int)hd->seq[1]][l];
  224. if (tmp > 1.0 + 1e-6 || tmp < 1.0 - 1e-6) // in theory, tmp should always equal to 1
  225. fprintf(stderr, "++ Underflow may have happened (%lg).\n", tmp);
  226. }
  227. // log-likelihood of the observation
  228. FLOAT hmm_lk(const hmm_data_t *hd)
  229. {
  230. FLOAT sum = 0.0, prod = 1.0;
  231. int u, L;
  232. L = hd->L;
  233. assert(hd->status & HMM_FORWARD);
  234. for (u = 1; u <= L; ++u) {
  235. prod *= hd->s[u];
  236. if (prod < HMM_TINY || prod >= 1.0/HMM_TINY) { // reset
  237. sum += log(prod);
  238. prod = 1.0;
  239. }
  240. }
  241. sum += log(prod);
  242. return sum;
  243. }
  244. // posterior decoding
  245. void hmm_post_decode(const hmm_par_t *hp, hmm_data_t *hd)
  246. {
  247. int u, k;
  248. assert(hd->status && HMM_BACKWARD);
  249. if (hd->p) free(hd->p);
  250. hd->p = (int*)calloc(hd->L + 1, sizeof(int));
  251. for (u = 1; u <= hd->L; ++u) {
  252. FLOAT prob, max, *fu = hd->f[u], *bu = hd->b[u], su = hd->s[u];
  253. int max_k;
  254. for (k = 0, max = -1.0, max_k = -1; k != hp->n; ++k) {
  255. if (max < (prob = fu[k] * bu[k] * su)) {
  256. max = prob; max_k = k;
  257. }
  258. }
  259. assert(max_k >= 0);
  260. hd->p[u] = max_k;
  261. }
  262. hd->status |= HMM_POSTDEC;
  263. }
  264. // posterior probability of states
  265. FLOAT hmm_post_state(const hmm_par_t *hp, const hmm_data_t *hd, int u, FLOAT *prob)
  266. {
  267. FLOAT sum = 0.0, ss = hd->s[u], *fu = hd->f[u], *bu = hd->b[u];
  268. int k;
  269. for (k = 0; k != hp->n; ++k)
  270. sum += (prob[k] = fu[k] * bu[k] * ss);
  271. return sum; // in theory, this should always equal to 1.0
  272. }
  273. // expected counts
  274. hmm_exp_t *hmm_expect(const hmm_par_t *hp, const hmm_data_t *hd)
  275. {
  276. int k, l, u, b, m, n;
  277. hmm_exp_t *he;
  278. assert(hd->status & HMM_BACKWARD);
  279. he = hmm_new_exp(hp);
  280. // initialization
  281. m = hp->m; n = hp->n;
  282. for (k = 0; k != n; ++k)
  283. for (l = 0; l != n; ++l) he->A[k][l] = HMM_TINY;
  284. for (b = 0; b <= m; ++b)
  285. for (l = 0; l != n; ++l) he->E[b][l] = HMM_TINY;
  286. // calculate A_{kl} and E_k(b), k,l\in[0,n)
  287. for (u = 1; u < hd->L; ++u) {
  288. FLOAT *fu = hd->f[u], *bu = hd->b[u], *bu1 = hd->b[u+1], ss = hd->s[u];
  289. FLOAT *Ec = he->E[(int)hd->seq[u]], **p = hp->ae + (int)hd->seq[u+1] * n;
  290. for (k = 0; k != n; ++k) {
  291. FLOAT *q = p[k], *AA = he->A[k], fuk = fu[k];
  292. for (l = 0; l != n; ++l) // this is cache-efficient
  293. AA[l] += fuk * q[l] * bu1[l];
  294. Ec[k] += fuk * bu[k] * ss;
  295. }
  296. }
  297. // calculate A0_l
  298. for (l = 0; l != n; ++l)
  299. he->A0[l] += hp->a0[l] * hp->e[(int)hd->seq[1]][l] * hd->b[1][l];
  300. return he;
  301. }
  302. FLOAT hmm_Q0(const hmm_par_t *hp, hmm_exp_t *he)
  303. {
  304. int k, l, b;
  305. FLOAT sum = 0.0;
  306. for (k = 0; k != hp->n; ++k) {
  307. FLOAT tmp;
  308. for (b = 0, tmp = 0.0; b != hp->m; ++b) tmp += he->E[b][k];
  309. for (b = 0; b != hp->m; ++b)
  310. sum += he->E[b][k] * log(he->E[b][k] / tmp);
  311. }
  312. for (k = 0; k != hp->n; ++k) {
  313. FLOAT tmp, *A = he->A[k];
  314. for (l = 0, tmp = 0.0; l != hp->n; ++l) tmp += A[l];
  315. for (l = 0; l != hp->n; ++l) sum += A[l] * log(A[l] / tmp);
  316. }
  317. return (he->Q0 = sum);
  318. }
  319. // add he0 to he1
  320. void hmm_add_expect(const hmm_exp_t *he0, hmm_exp_t *he1)
  321. {
  322. int b, k, l;
  323. assert(he0->m == he1->m && he0->n == he1->n);
  324. for (k = 0; k != he1->n; ++k) {
  325. he1->A0[k] += he0->A0[k];
  326. for (l = 0; l != he1->n; ++l)
  327. he1->A[k][l] += he0->A[k][l];
  328. }
  329. for (b = 0; b != he1->m; ++b) {
  330. for (l = 0; l != he1->n; ++l)
  331. he1->E[b][l] += he0->E[b][l];
  332. }
  333. }
  334. // the EM-Q function
  335. FLOAT hmm_Q(const hmm_par_t *hp, const hmm_exp_t *he)
  336. {
  337. FLOAT sum = 0.0;
  338. int bb, k, l;
  339. for (bb = 0; bb != he->m; ++bb) {
  340. FLOAT *eb = hp->e[bb], *Eb = he->E[bb];
  341. for (k = 0; k != hp->n; ++k) {
  342. if (eb[k] <= 0.0) return -HMM_INF;
  343. sum += Eb[k] * log(eb[k]);
  344. }
  345. }
  346. for (k = 0; k != he->n; ++k) {
  347. FLOAT *Ak = he->A[k], *ak = hp->a[k];
  348. for (l = 0; l != he->n; ++l) {
  349. if (ak[l] <= 0.0) return -HMM_INF;
  350. sum += Ak[l] * log(ak[l]);
  351. }
  352. }
  353. return (sum -= he->Q0);
  354. }
  355. // simulate sequence
  356. char *hmm_simulate(const hmm_par_t *hp, int L)
  357. {
  358. int i, k, l, b;
  359. FLOAT x, y, **et;
  360. char *seq;
  361. seq = (char*)calloc(L+1, 1);
  362. // calculate the transpose of hp->e[][]
  363. et = (FLOAT**)calloc2(hp->n, hp->m, sizeof(FLOAT));
  364. for (k = 0; k != hp->n; ++k)
  365. for (b = 0; b != hp->m; ++b)
  366. et[k][b] = hp->e[b][k];
  367. // the initial state, drawn from a0[]
  368. x = drand48();
  369. for (k = 0, y = 0.0; k != hp->n; ++k) {
  370. y += hp->a0[k];
  371. if (y >= x) break;
  372. }
  373. // main loop
  374. for (i = 0; i != L; ++i) {
  375. FLOAT *el, *ak = hp->a[k];
  376. x = drand48();
  377. for (l = 0, y = 0.0; l != hp->n; ++l) {
  378. y += ak[l];
  379. if (y >= x) break;
  380. }
  381. el = et[l];
  382. x = drand48();
  383. for (b = 0, y = 0.0; b != hp->m; ++b) {
  384. y += el[b];
  385. if (y >= x) break;
  386. }
  387. seq[i] = b;
  388. k = l;
  389. }
  390. for (k = 0; k != hp->n; ++k) free(et[k]);
  391. free(et);
  392. return seq;
  393. }