From 2bd96d9ecd73b71b3666c7d1931ec3e33e5f49fb Mon Sep 17 00:00:00 2001 From: Daniel Llorens Date: Thu, 25 Apr 2013 15:18:05 +0200 Subject: [PATCH] Fix corner cases of scm_ramapc * libguile/array-map.c - (scm_ramapc): mismatched axes limit unrollk (kroll). Reorganize the function to do all checking as we go. - (scm_ra_matchp): unused; remove. - (find_unrollk): inlined in scm_ramapc; remove. - (klen): inlined in scm_ramapc; remove. - (rafill): n is size_t. - (racp): n is size_t. Use n and not i0end to bound the loop. - (ramap): Use n and not i0end to bound the loop. This is needed for the rank 0 case to work with the new scm_ramapc, as inc may be set to 0 in that case. - (rafe): idem. * test-suite/tests/ramap.test - check that size mismatch prevents unrolling (matching behavior III) with both array-copy! and array-map!. - check that non-contiguous stride in non-ref args prevents unrolling (rank 2, discontinuous) with both array-copy! and array-map!. - check rank 0 cases with array-for-each, array-map!. --- libguile/array-map.c | 371 +++++++++++++++--------------------- test-suite/tests/ramap.test | 62 +++++- 2 files changed, 218 insertions(+), 215 deletions(-) diff --git a/libguile/array-map.c b/libguile/array-map.c index 8cd97aafe..70e3e676f 100644 --- a/libguile/array-map.c +++ b/libguile/array-map.c @@ -70,72 +70,6 @@ ASET (SCM v, size_t pos, SCM val) scm_array_handle_release (&h); } -/* Checker for scm_array mapping functions, returns: - - 5 --> empty axes; - 4 --> shapes, increments, and bases are the same; - 3 --> shapes and increments are the same; - 2 --> shapes are the same; - 1 --> ras are at least as big as ra0; - 0 --> no match. 
- */ - -int -scm_ra_matchp (SCM ra0, SCM ras) -{ - int i, exact = 4, empty = 0; - scm_t_array_handle h0; - - scm_array_get_handle (ra0, &h0); - for (i = 0; i < h0.ndims; ++i) - { - empty = empty || (h0.dims[i].lbnd > h0.dims[i].ubnd); - } - - while (scm_is_pair (ras)) - { - scm_t_array_handle h1; - - scm_array_get_handle (SCM_CAR (ras), &h1); - - if (h0.ndims != h1.ndims) - { - scm_array_handle_release (&h0); - scm_array_handle_release (&h1); - return 0; - } - if (h0.base != h1.base) - exact = min(3, exact); - - for (i = 0; i < h0.ndims; ++i) - { - empty = empty || (h1.dims[i].lbnd > h1.dims[i].ubnd); - switch (exact) - { - case 4: - case 3: - if (h0.dims[i].inc != h1.dims[i].inc) - exact = 2; - case 2: - if (h0.dims[i].lbnd == h1.dims[i].lbnd && h0.dims[i].ubnd == h1.dims[i].ubnd) - break; - exact = 1; - default: - if (h0.dims[i].lbnd < h1.dims[i].lbnd || h0.dims[i].ubnd > h1.dims[i].ubnd) - { - scm_array_handle_release (&h0); - scm_array_handle_release (&h1); - return 0; - } - } - } - scm_array_handle_release (&h1); - ras = SCM_CDR (ras); - } - scm_array_handle_release (&h0); - return empty ? 5 : exact; -} - static SCM make1array (SCM v, ssize_t inc) { @@ -148,46 +82,11 @@ make1array (SCM v, ssize_t inc) return a; } -/* Find down to which rank the array is unrollable. 0 means fully - unrollable, which all rank-0 and rank-1 arrays are. */ -static int -find_unrollk (SCM ra, int k) -{ - if (k <= 0) - return 0; - else - { - ssize_t inc; - inc = SCM_I_ARRAY_DIMS (ra)[k].inc; - do { - size_t lenk = (SCM_I_ARRAY_DIMS (ra)[k].ubnd - - SCM_I_ARRAY_DIMS (ra)[k].lbnd + 1); - inc *= lenk; - --k; - } while (k >= 0 && inc == SCM_I_ARRAY_DIMS (ra)[k].inc); - return k+1; - } -} - -/* Length of the unrolled index set. */ -static size_t -klen (SCM ra, int kbegin, int kend) -{ - size_t len = 1; - int k; - for (k = kbegin; k < kend; ++k) - len *= (SCM_I_ARRAY_DIMS (ra)[k].ubnd - - SCM_I_ARRAY_DIMS (ra)[k].lbnd + 1); - return len; -} - -/* Linear index of the NOT unrolled index set. 
*/ +/* Linear index of not-unrolled index set. */ static size_t cindk (SCM ra, ssize_t *ve, int kend) { - if (!SCM_I_ARRAYP (ra)) - return 0; /* this is BASE */ - else + if (SCM_I_ARRAYP (ra)) { int k; size_t i = SCM_I_ARRAY_BASE (ra); @@ -195,6 +94,8 @@ cindk (SCM ra, ssize_t *ve, int kend) i += (ve[k] - SCM_I_ARRAY_DIMS (ra)[k].lbnd) * SCM_I_ARRAY_DIMS (ra)[k].inc; return i; } + else + return 0; /* this is BASE */ } /* array mapper: apply cproc to each dimension of the given arrays?. @@ -205,120 +106,163 @@ cindk (SCM ra, ssize_t *ve, int kend) SCM ra0; destination array. SCM lra; list of source arrays. const char *what; caller, for error reporting. */ + +#define LBND(ra, k) SCM_I_ARRAY_DIMS (ra)[k].lbnd +#define UBND(ra, k) SCM_I_ARRAY_DIMS (ra)[k].ubnd + int scm_ramapc (void *cproc_ptr, SCM data, SCM ra0, SCM lra, const char *what) { - SCM z; - SCM vra0; - SCM lvra, *plvra; + SCM z, vra0, lvra, *plvra; ssize_t *vi; - int k, kmax, unrollk; + int k, kmax, kroll; int (*cproc) () = cproc_ptr; - size_t unrolled_len; + int empty = 0; - switch (scm_ra_matchp (ra0, lra)) + /* Prepare reference argument. */ + if (SCM_I_ARRAYP (ra0)) { - default: - case 0: - scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); - case 1: - case 2: - case 3: - case 4: + k = kmax = SCM_I_ARRAY_NDIM (ra0)-1; + vra0 = make1array (SCM_I_ARRAY_V (ra0), SCM_I_ARRAY_DIMS (ra0)[kmax].inc); - /* Prepare reference argument */ - if (SCM_I_ARRAYP (ra0)) + /* Find unroll depth */ + if (k > 0) { - kmax = SCM_I_ARRAY_NDIM (ra0)-1; - vra0 = make1array (SCM_I_ARRAY_V (ra0), SCM_I_ARRAY_DIMS (ra0)[kmax].inc); + ssize_t inc = SCM_I_ARRAY_DIMS (ra0)[k].inc; + do { + inc *= (UBND (ra0, k) - LBND (ra0, k) + 1); + --k; + } while (k >= 0 && inc == SCM_I_ARRAY_DIMS (ra0)[k].inc); + kroll = k+1; + empty = 0 == inc; + } + else + kroll = 0; + + /* Check emptiness of not-unrolled axes. 
*/ + for (; k>=0 && !empty; --k) + empty = (0 == (UBND (ra0, k) - LBND (ra0, k) + 1)); + } + else + { + kroll = kmax = 0; + vra0 = ra0 = make1array (ra0, 1); + empty = (0 == (UBND (ra0, 0) - LBND (ra0, 0) + 1)); + } + + /* Prepare rest arguments. */ + lvra = SCM_EOL; + plvra = &lvra; + for (z = lra; !scm_is_null (z); z = SCM_CDR (z)) + { + SCM ra1 = SCM_CAR (z); + SCM vra1; + if (SCM_I_ARRAYP (ra1)) + { + if (kmax != SCM_I_ARRAY_NDIM (ra1) - 1) + scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); + vra1 = make1array (SCM_I_ARRAY_V (ra1), SCM_I_ARRAY_DIMS (ra1)[kmax].inc); + + /* Check unroll depth. */ + k = kmax; + if (k > kroll) + { + ssize_t inc = SCM_I_ARRAY_DIMS (ra1)[k].inc; + do { + ssize_t l0 = LBND (ra0, k), u0 = UBND (ra0, k); + ssize_t l1 = LBND (ra1, k), u1 = UBND (ra1, k); + --k; + if (l0 == l1 && u0 == u1) + inc *= (u1 - l1 + 1); + else if (l0 >= l1 && u0 <= u1) + break; + else + scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); + } while (k >= kroll && inc == SCM_I_ARRAY_DIMS (ra1)[k].inc); + kroll = k + 1; + } + + /* Check matching of not-unrolled axes. 
*/ + for (; k>=0; --k) + if (LBND (ra0, k) < LBND (ra1, k) || UBND (ra0, k) > UBND (ra1, k)) + scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); } else { - kmax = 0; - vra0 = ra0 = make1array(ra0, 1); + if (kmax != 0) + scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); + vra1 = make1array (ra1, 1); + + if (LBND (ra0, 0) < LBND (vra1, 0) || UBND (ra0, 0) > UBND (vra1, 0)) + scm_misc_error (what, "array shape mismatch: ~S", scm_list_1 (ra0)); } - - /* Linear addressing for rest arguments */ - lvra = SCM_EOL; - plvra = &lvra; - for (z = lra; !scm_is_null (z); z = SCM_CDR (z)) - { - SCM ra1 = SCM_CAR (z); - SCM vra1; - if (SCM_I_ARRAYP (ra1)) - vra1 = make1array (SCM_I_ARRAY_V (ra1), SCM_I_ARRAY_DIMS (ra1)[kmax].inc); - else - vra1 = make1array (ra1, 1); - *plvra = scm_cons (vra1, SCM_EOL); - plvra = SCM_CDRLOC (*plvra); - } - - /* Find common unroll depth */ - unrollk = find_unrollk (ra0, kmax); - for (z = lra; !scm_is_null (z); z = SCM_CDR (z)) - { - SCM ra1 = SCM_CAR (z); - unrollk = max(unrollk, find_unrollk (ra1, kmax)); - } - unrolled_len = klen (ra0, unrollk, kmax+1); - - /* Set inner loop size */ - SCM_I_ARRAY_DIMS (vra0)->lbnd = 0; - SCM_I_ARRAY_DIMS (vra0)->ubnd = unrolled_len - 1; - for (z = lvra; !scm_is_null (z); z = SCM_CDR (z)) - { - SCM_I_ARRAY_DIMS (SCM_CAR (z))->lbnd = 0; - SCM_I_ARRAY_DIMS (SCM_CAR (z))->ubnd = unrolled_len - 1; - } - - /* Set starting indices and go */ - vi = scm_gc_malloc_pointerless (sizeof(ssize_t) * unrollk, vi_gc_hint); - for (k = 0; k < unrollk; ++k) - vi[k] = SCM_I_ARRAY_DIMS (ra0)[k].lbnd; - do - { - if (k == unrollk) - { - SCM y = lra; - SCM_I_ARRAY_BASE (vra0) = cindk (ra0, vi, unrollk); - for (z = lvra; !scm_is_null (z); z = SCM_CDR (z), y = SCM_CDR (y)) - SCM_I_ARRAY_BASE (SCM_CAR (z)) = cindk (SCM_CAR (y), vi, unrollk); - if (SCM_UNBNDP (data)) - cproc (vra0, lvra); - else - cproc (vra0, data, lvra); - k--; - } - else if (vi[k] < SCM_I_ARRAY_DIMS (ra0)[k].ubnd) - { - vi[k]++; - k++; 
- } - else - { - vi[k] = SCM_I_ARRAY_DIMS (ra0)[k].lbnd - 1; - k--; - } - } - while (k >= 0); - - return 1; + *plvra = scm_cons (vra1, SCM_EOL); + plvra = SCM_CDRLOC (*plvra); } + + /* Set unrolled size. */ + if (empty) + return 1; + else + { + size_t len = 1; + for (k = kroll; k <= kmax; ++k) + len *= (UBND (ra0, k) - LBND (ra0, k) + 1); + UBND (vra0, 0) = len - 1; + for (z = lvra; !scm_is_null (z); z = SCM_CDR (z)) + UBND (SCM_CAR (z), 0) = len - 1; + } + + /* Set starting indices and go. */ + vi = scm_gc_malloc_pointerless (sizeof(ssize_t) * kroll, vi_gc_hint); + for (k = 0; k < kroll; ++k) + vi[k] = LBND (ra0, k); + do + { + if (k == kroll) + { + SCM y = lra; + SCM_I_ARRAY_BASE (vra0) = cindk (ra0, vi, kroll); + for (z = lvra; !scm_is_null (z); z = SCM_CDR (z), y = SCM_CDR (y)) + SCM_I_ARRAY_BASE (SCM_CAR (z)) = cindk (SCM_CAR (y), vi, kroll); + if (SCM_UNBNDP (data)) + cproc (vra0, lvra); + else + cproc (vra0, data, lvra); + k--; + } + else if (vi[k] < SCM_I_ARRAY_DIMS (ra0)[k].ubnd) + { + vi[k]++; + k++; + } + else + { + vi[k] = SCM_I_ARRAY_DIMS (ra0)[k].lbnd - 1; + k--; + } + } + while (k >= 0); + return 1; } +#undef UBND +#undef LBND + static int rafill (SCM dst, SCM fill) { - long n = (SCM_I_ARRAY_DIMS (dst)->ubnd - SCM_I_ARRAY_DIMS (dst)->lbnd + 1); scm_t_array_handle h; - size_t i; + size_t n, i; ssize_t inc; scm_array_get_handle (SCM_I_ARRAY_V (dst), &h); i = SCM_I_ARRAY_BASE (dst); inc = SCM_I_ARRAY_DIMS (dst)->inc; + n = (SCM_I_ARRAY_DIMS (dst)->ubnd - SCM_I_ARRAY_DIMS (dst)->lbnd + 1); + dst = SCM_I_ARRAY_V (dst); for (; n-- > 0; i += inc) - h.impl->vset (SCM_I_ARRAY_V (dst), i, fill); + h.impl->vset (dst, i, fill); scm_array_handle_release (&h); return 1; @@ -339,9 +283,8 @@ SCM_DEFINE (scm_array_fill_x, "array-fill!", 2, 0, 0, static int racp (SCM src, SCM dst) { - ssize_t n = (SCM_I_ARRAY_DIMS (src)->ubnd - SCM_I_ARRAY_DIMS (src)->lbnd + 1); scm_t_array_handle h_s, h_d; - size_t i_s, i_d; + size_t n, i_s, i_d; ssize_t inc_s, inc_d; dst = SCM_CAR 
(dst); @@ -349,11 +292,14 @@ racp (SCM src, SCM dst) i_d = SCM_I_ARRAY_BASE (dst); inc_s = SCM_I_ARRAY_DIMS (src)->inc; inc_d = SCM_I_ARRAY_DIMS (dst)->inc; + n = (SCM_I_ARRAY_DIMS (src)->ubnd - SCM_I_ARRAY_DIMS (src)->lbnd + 1); + src = SCM_I_ARRAY_V (src); + dst = SCM_I_ARRAY_V (dst); - scm_array_get_handle (SCM_I_ARRAY_V (src), &h_s); - scm_array_get_handle (SCM_I_ARRAY_V (dst), &h_d); + scm_array_get_handle (src, &h_s); + scm_array_get_handle (dst, &h_d); - if (scm_is_vector (SCM_I_ARRAY_V (src)) && scm_is_vector (SCM_I_ARRAY_V (dst))) + if (scm_is_vector (src) && scm_is_vector (dst)) { SCM const * el_s = h_s.elements; SCM * el_d = h_d.writable_elements; @@ -362,7 +308,7 @@ racp (SCM src, SCM dst) } else for (; n-- > 0; i_s += inc_s, i_d += inc_d) - h_d.impl->vset (SCM_I_ARRAY_V (dst), i_d, h_s.impl->vref (SCM_I_ARRAY_V (src), i_s)); + h_d.impl->vset (dst, i_d, h_s.impl->vref (src, i_s)); scm_array_handle_release (&h_d); scm_array_handle_release (&h_s); @@ -652,19 +598,18 @@ scm_array_identity (SCM dst, SCM src) static int ramap (SCM ra0, SCM proc, SCM ras) { - ssize_t i = SCM_I_ARRAY_DIMS (ra0)->lbnd; - size_t n = SCM_I_ARRAY_DIMS (ra0)->ubnd - i + 1; - scm_t_array_handle h0; - size_t i0, i0end; - ssize_t inc0; + size_t n, i0; + ssize_t i, inc0; scm_array_get_handle (SCM_I_ARRAY_V (ra0), &h0); i0 = SCM_I_ARRAY_BASE (ra0); inc0 = SCM_I_ARRAY_DIMS (ra0)->inc; - i0end = i0 + n*inc0; + i = SCM_I_ARRAY_DIMS (ra0)->lbnd; + n = SCM_I_ARRAY_DIMS (ra0)->ubnd - i + 1; + ra0 = SCM_I_ARRAY_V (ra0); if (scm_is_null (ras)) - for (; i0 < i0end; i0 += inc0) - h0.impl->vset (SCM_I_ARRAY_V (ra0), i0, scm_call_0 (proc)); + for (; n--; i0 += inc0) + h0.impl->vset (ra0, i0, scm_call_0 (proc)); else { SCM ra1 = SCM_CAR (ras); @@ -675,19 +620,20 @@ ramap (SCM ra0, SCM proc, SCM ras) i1 = SCM_I_ARRAY_BASE (ra1); inc1 = SCM_I_ARRAY_DIMS (ra1)->inc; ras = SCM_CDR (ras); + ra1 = SCM_I_ARRAY_V (ra1); if (scm_is_null (ras)) - for (; i0 < i0end; i0 += inc0, i1 += inc1) - h0.impl->vset 
(SCM_I_ARRAY_V (ra0), i0, scm_call_1 (proc, h1.impl->vref (SCM_I_ARRAY_V (ra1), i1))); + for (; n--; i0 += inc0, i1 += inc1) + h0.impl->vset (ra0, i0, scm_call_1 (proc, h1.impl->vref (ra1, i1))); else { ras = scm_vector (ras); - for (; i0 < i0end; i0 += inc0, i1 += inc1, ++i) + for (; n--; i0 += inc0, i1 += inc1, ++i) { SCM args = SCM_EOL; unsigned long k; for (k = scm_c_vector_length (ras); k--;) args = scm_cons (AREF (scm_c_vector_ref (ras, k), i), args); - h0.impl->vset (SCM_I_ARRAY_V (ra0), i0, scm_apply_1 (proc, h1.impl->vref (SCM_I_ARRAY_V (ra1), i1), args)); + h0.impl->vset (ra0, i0, scm_apply_1 (proc, h1.impl->vref (ra1, i1), args)); } } scm_array_handle_release (&h1); @@ -729,19 +675,18 @@ rafe (SCM ra0, SCM proc, SCM ras) size_t n = SCM_I_ARRAY_DIMS (ra0)->ubnd - i + 1; scm_t_array_handle h0; - size_t i0, i0end; + size_t i0; ssize_t inc0; scm_array_get_handle (SCM_I_ARRAY_V (ra0), &h0); i0 = SCM_I_ARRAY_BASE (ra0); inc0 = SCM_I_ARRAY_DIMS (ra0)->inc; - i0end = i0 + n*inc0; if (scm_is_null (ras)) - for (; i0 < i0end; i0 += inc0) + for (; n--; i0 += inc0) scm_call_1 (proc, h0.impl->vref (SCM_I_ARRAY_V (ra0), i0)); else { ras = scm_vector (ras); - for (; i0 < i0end; i0 += inc0, ++i) + for (; n--; i0 += inc0, ++i) { SCM args = SCM_EOL; unsigned long k; diff --git a/test-suite/tests/ramap.test b/test-suite/tests/ramap.test index 037850582..db9d4e145 100644 --- a/test-suite/tests/ramap.test +++ b/test-suite/tests/ramap.test @@ -103,6 +103,15 @@ (array-copy! a b) (equal? b #(1 2)))) + ;; here both a & b are unrollable down to the first axis, but the + ;; size mismatch limits unrolling to the last axis only. + + (pass-if "matching behavior III" + (let ((a #3(((1 2) (3 4)) ((5 6) (7 8)))) + (b (make-array 0 2 3 2))) + (array-copy! a b) + (array-equal? b #3(((1 2) (3 4) (0 0)) ((5 6) (7 8) (0 0)))))) + (pass-if "rank 2" (let ((a #2((1 2) (3 4))) (b (make-array 0 2 2)) @@ -119,6 +128,19 @@ (equal? d #2((1 3) (2 4))) (equal? 
e #2((1 2) (3 4)))))) + (pass-if "rank 2, discontinuous" + (let ((A #2((0 1) (2 3) (4 5))) + (B #2((10 11) (12 13) (14 15))) + (C #2((20) (21) (22))) + (X (make-array 0 3 5)) + (piece (lambda (X w s) + (make-shared-array + X (lambda (i j) (list i (+ j s))) 3 w)))) + (array-map! A (piece X 2 0)) + (array-map! B (piece X 2 2)) + (array-map! C (piece X 1 4)) + (and (array-equal? X #2((0 1 10 11 20) (2 3 12 13 21) (4 5 14 15 22)))))) + (pass-if "rank 1" (let* ((a #2((1 2) (3 4))) (b (make-shared-array a (lambda (j) (list 1 j)) 2)) @@ -235,7 +257,26 @@ (pass-if "1+" (let ((a (make-array #f 5))) (array-map! a 1+ (make-array 123 5)) - (equal? a (make-array 124 5))))) + (equal? a (make-array 124 5)))) + + (pass-if "rank 0" + (let ((a #0(99)) + (b (make-array 0))) + (array-map! b values a) + (equal? b #0(99)))) + + (pass-if "rank 2, discontinuous" + (let ((A #2((0 1) (2 3) (4 5))) + (B #2((10 11) (12 13) (14 15))) + (C #2((20) (21) (22))) + (X (make-array 0 3 5)) + (piece (lambda (X w s) + (make-shared-array + X (lambda (i j) (list i (+ j s))) 3 w)))) + (array-map! (piece X 2 0) values A) + (array-map! (piece X 2 2) values B) + (array-map! (piece X 1 4) values C) + (and (array-equal? X #2((0 1 10 11 20) (2 3 12 13 21) (4 5 14 15 22))))))) (with-test-prefix "two sources" @@ -337,7 +378,16 @@ (let ((a #(1 2 3)) (b (make-array 0 2))) (array-map! b values a) - (equal? b #(1 2))))) + (equal? b #(1 2)))) + + ;; here both a & b are unrollable down to the first axis, but the + ;; size mismatch limits unrolling to the last axis only. + + (pass-if "matching behavior III" + (let ((a #3(((1 2) (3 4) (5 6)) ((7 8) (9 10) (11 12)))) + (b (make-array 0 2 2 2))) + (array-map! b values a) + (array-equal? b #3(((1 2) (3 4)) ((7 8) (9 10))))))) ;;; ;;; array-for-each @@ -346,6 +396,14 @@ (with-test-prefix "array-for-each" (with-test-prefix "1 source" + (pass-if-equal "rank 0" + '(99) + (let* ((a #0(99)) + (l '()) + (p (lambda (x) (set! 
l (cons x l))))) + (array-for-each p a) + l)) + (pass-if-equal "noncompact array" '(3 2 1 0) (let* ((a #2((0 1) (2 3)))