kthread.c 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #include <pthread.h>
  2. #include <stdlib.h>
  3. #include <limits.h>
  4. /************
  5. * kt_for() *
  6. ************/
  7. struct kt_for_t;
  8. typedef struct {
  9. struct kt_for_t *t;
  10. long i;
  11. } ktf_worker_t;
  12. typedef struct kt_for_t {
  13. int n_threads;
  14. long n;
  15. ktf_worker_t *w;
  16. void (*func)(void*,long,int);
  17. void *data;
  18. } kt_for_t;
  19. static inline long steal_work(kt_for_t *t)
  20. {
  21. int i, min_i = -1;
  22. long k, min = LONG_MAX;
  23. for (i = 0; i < t->n_threads; ++i)
  24. if (min > t->w[i].i) min = t->w[i].i, min_i = i;
  25. k = __sync_fetch_and_add(&t->w[min_i].i, t->n_threads);
  26. return k >= t->n? -1 : k;
  27. }
  28. static void *ktf_worker(void *data)
  29. {
  30. ktf_worker_t *w = (ktf_worker_t*)data;
  31. long i;
  32. for (;;) {
  33. i = __sync_fetch_and_add(&w->i, w->t->n_threads);
  34. if (i >= w->t->n) break;
  35. w->t->func(w->t->data, i, w - w->t->w);
  36. }
  37. while ((i = steal_work(w->t)) >= 0)
  38. w->t->func(w->t->data, i, w - w->t->w);
  39. pthread_exit(0);
  40. }
  41. void kt_for(int n_threads, void (*func)(void*,long,int), void *data, long n)
  42. {
  43. int i;
  44. kt_for_t t;
  45. pthread_t *tid;
  46. t.func = func, t.data = data, t.n_threads = n_threads, t.n = n;
  47. t.w = (ktf_worker_t*)alloca(n_threads * sizeof(ktf_worker_t));
  48. tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
  49. for (i = 0; i < n_threads; ++i)
  50. t.w[i].t = &t, t.w[i].i = i;
  51. for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktf_worker, &t.w[i]);
  52. for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);
  53. }
  54. /*****************
  55. * kt_pipeline() *
  56. *****************/
  57. struct ktp_t;
  58. typedef struct {
  59. struct ktp_t *pl;
  60. int step, running;
  61. void *data;
  62. } ktp_worker_t;
  63. typedef struct ktp_t {
  64. void *shared;
  65. void *(*func)(void*, int, void*);
  66. int n_workers, n_steps;
  67. ktp_worker_t *workers;
  68. pthread_mutex_t mutex;
  69. pthread_cond_t cv;
  70. } ktp_t;
  71. static void *ktp_worker(void *data)
  72. {
  73. ktp_worker_t *w = (ktp_worker_t*)data;
  74. ktp_t *p = w->pl;
  75. while (w->step < p->n_steps) {
  76. // test whether we can kick off the job with this worker
  77. pthread_mutex_lock(&p->mutex);
  78. for (;;) {
  79. int i;
  80. // test whether another worker is doing the same step
  81. for (i = 0; i < p->n_workers; ++i) {
  82. if (w == &p->workers[i]) continue; // ignore itself
  83. if (p->workers[i].running && p->workers[i].step == w->step)
  84. break;
  85. }
  86. if (i == p->n_workers) break; // no other workers doing w->step; then this worker will
  87. pthread_cond_wait(&p->cv, &p->mutex);
  88. }
  89. w->running = 1;
  90. pthread_mutex_unlock(&p->mutex);
  91. // working on w->step
  92. w->data = p->func(p->shared, w->step, w->step? w->data : 0); // for the first step, input is NULL
  93. // update step and let other workers know
  94. pthread_mutex_lock(&p->mutex);
  95. w->step = w->step == p->n_steps - 1 || w->data? (w->step + 1) % p->n_steps : p->n_steps;
  96. w->running = 0;
  97. pthread_cond_broadcast(&p->cv);
  98. pthread_mutex_unlock(&p->mutex);
  99. }
  100. pthread_exit(0);
  101. }
  102. void kt_pipeline(int n_threads, void *(*func)(void*, int, void*), void *shared_data, int n_steps)
  103. {
  104. ktp_t aux;
  105. pthread_t *tid;
  106. int i;
  107. if (n_threads < 1) n_threads = 1;
  108. aux.n_workers = n_threads;
  109. aux.n_steps = n_steps;
  110. aux.func = func;
  111. aux.shared = shared_data;
  112. pthread_mutex_init(&aux.mutex, 0);
  113. pthread_cond_init(&aux.cv, 0);
  114. aux.workers = alloca(n_threads * sizeof(ktp_worker_t));
  115. for (i = 0; i < n_threads; ++i) {
  116. ktp_worker_t *w = &aux.workers[i];
  117. w->step = w->running = 0; w->pl = &aux; w->data = 0;
  118. }
  119. tid = alloca(n_threads * sizeof(pthread_t));
  120. for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktp_worker, &aux.workers[i]);
  121. for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);
  122. pthread_mutex_destroy(&aux.mutex);
  123. pthread_cond_destroy(&aux.cv);
  124. }