diff --git a/module/language/cps/intset.scm b/module/language/cps/intset.scm index e8e6df2d6..b235c61b0 100644 --- a/module/language/cps/intset.scm +++ b/module/language/cps/intset.scm @@ -35,7 +35,8 @@ intset-ref intset-next intset-union - intset-intersect)) + intset-intersect + intset-subtract)) (define-syntax-rule (define-inline name val) (define-syntax name (identifier-syntax val))) @@ -454,3 +455,81 @@ ((eq? root a-root) a) ((eq? root b-root) b) (else (make-intset/prune a-min a-shift root))))))))) + +(define (intset-subtract a b) + (define tmp (new-leaf)) + ;; Intersect leaves. + (define (subtract-leaves a b) + (logand a (lognot b))) + ;; Subtract B from A starting at index I; the result will be fresh. + (define (subtract-branches/fresh shift a b i fresh) + (let lp ((i 0)) + (cond + ((< i *branch-size*) + (let* ((a-child (vector-ref a i)) + (b-child (vector-ref b i))) + (vector-set! fresh i (subtract-nodes shift a-child b-child)) + (lp (1+ i)))) + ((branch-empty? fresh) #f) + (else fresh)))) + ;; Subtract B from A. The result may be eq? to A. + (define (subtract-branches shift a b) + (let lp ((i 0)) + (cond + ((< i *branch-size*) + (let* ((a-child (vector-ref a i)) + (b-child (vector-ref b i))) + (let ((child (subtract-nodes shift a-child b-child))) + (cond + ((eq? a-child child) + (lp (1+ i))) + (else + (let ((result (clone-branch-and-set a i child))) + (subtract-branches/fresh shift a b (1+ i) result))))))) + (else a)))) + (define (subtract-nodes shift a-node b-node) + (cond + ((or (not a-node) (not b-node)) a-node) + ((eq? a-node b-node) #f) + ((= shift *leaf-bits*) (subtract-leaves a-node b-node)) + (else (subtract-branches (- shift *branch-bits*) a-node b-node)))) + + (match (cons a b) + ((($ a-min a-shift a-root) . ($ b-min b-shift b-root)) + (define (return root) + (cond + ((eq? root a-root) a) + (else (make-intset/prune a-min a-shift root)))) + (cond + ((<= a-shift b-shift) + (let lp ((b-min b-min) (b-shift b-shift) (b-root b-root)) + (if (= a-shift b-shift) + (if (= a-min b-min) + (return (subtract-nodes a-shift a-root b-root)) + a) + (let* ((b-shift (- b-shift *branch-bits*)) + (b-idx (ash (- a-min b-min) (- b-shift))) + (b-min (+ b-min (ash b-idx b-shift))) + (b-root (and b-root + (<= 0 b-idx) + (< b-idx *branch-size*) + (vector-ref b-root b-idx)))) + (lp b-min b-shift b-root))))) + (else + (return + (let lp ((a-min a-min) (a-shift a-shift) (a-root a-root)) + (if (= a-shift b-shift) + (if (= a-min b-min) + (subtract-nodes a-shift a-root b-root) + a-root) + (let* ((a-shift (- a-shift *branch-bits*)) + (a-idx (ash (- b-min a-min) (- a-shift))) + (a-min (+ a-min (ash a-idx a-shift))) + (old (and a-root + (<= 0 a-idx) + (< a-idx *branch-size*) + (vector-ref a-root a-idx))) + (new (lp a-min a-shift old))) + (if (eq? old new) + a-root + (clone-branch-and-set a-root a-idx new)))))))))))