diff --git a/module/language/cps/reify-primitives.scm b/module/language/cps/reify-primitives.scm index 0823584d4..a473f954f 100644 --- a/module/language/cps/reify-primitives.scm +++ b/module/language/cps/reify-primitives.scm @@ -143,6 +143,14 @@ ($ $continue k src ($ $primcall 'call-thunk/no-inline #f (proc)))) (with-cps cps (setk label ($kargs names vars ($continue k src ($call proc ())))))) + (($ $kargs names vars + ($ $continue k src ($ $primcall 'mul/immediate b (a)))) + (with-cps cps + (letv b*) + (letk kb ($kargs ('b) (b*) + ($continue k src ($primcall 'mul #f (a b*))))) + (setk label ($kargs names vars + ($continue kb src ($const b)))))) (($ $kargs names vars ($ $continue k src ($ $primcall name param args))) (cond ((or (prim-instruction name) (branching-primitive? name)) diff --git a/module/language/cps/type-fold.scm b/module/language/cps/type-fold.scm index 1e5800968..bf016ec83 100644 --- a/module/language/cps/type-fold.scm +++ b/module/language/cps/type-fold.scm @@ -161,6 +161,68 @@ (define-branch-folder-alias s64-= =) + + +;; Convert e.g. rsh to rsh/immediate. + +(define *primcall-macro-reducers* (make-hash-table)) + +(define-syntax-rule (define-primcall-macro-reducer name f) + (hashq-set! *primcall-macro-reducers* 'name f)) + +(define-syntax-rule (define-unary-primcall-macro-reducer (name cps k src + arg type min max) + body ...) + (define-primcall-macro-reducer name + (lambda (cps k src param arg type min max) + body ...))) + +(define-syntax-rule (define-binary-primcall-macro-reducer + (name cps k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + body ...) + (define-primcall-macro-reducer name + (lambda (cps k src param arg0 type0 min0 max0 arg1 type1 min1 max1) + body ...))) + +(define-binary-primcall-macro-reducer (mul cps k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + (cond + ((and (type<=? type0 &exact-integer) (= min0 max0)) + (with-cps cps + (build-term + ($continue k src ($primcall 'mul/immediate min0 (arg1)))))) + ((and (type<=? type1 &exact-integer) (= min1 max1)) + (with-cps cps + (build-term + ($continue k src ($primcall 'mul/immediate min1 (arg0)))))) + (else + (with-cps cps #f)))) + +(define-binary-primcall-macro-reducer (lsh cps k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + (cond + ((= min1 max1) + (with-cps cps + (build-term + ($continue k src ($primcall 'lsh/immediate min1 (arg0)))))) + (else + (with-cps cps #f)))) + +(define-binary-primcall-macro-reducer (rsh cps k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + (cond + ((= min1 max1) + (with-cps cps + (build-term + ($continue k src ($primcall 'rsh/immediate min1 (arg0)))))) + (else + (with-cps cps #f)))) + ;; Strength reduction. @@ -170,14 +232,14 @@ (define-syntax-rule (define-primcall-reducer name f) (hashq-set! *primcall-reducers* 'name f)) -(define-syntax-rule (define-unary-primcall-reducer (name cps k src +(define-syntax-rule (define-unary-primcall-reducer (name cps k src param arg type min max) body ...) (define-primcall-reducer name (lambda (cps k src param arg type min max) body ...))) -(define-syntax-rule (define-binary-primcall-reducer (name cps k src +(define-syntax-rule (define-binary-primcall-reducer (name cps k src param arg0 type0 min0 max0 arg1 type1 min1 max1) body ...) @@ -185,62 +247,42 @@ (lambda (cps k src param arg0 type0 min0 max0 arg1 type1 min1 max1) body ...))) -(define-binary-primcall-reducer (mul cps k src - arg0 type0 min0 max0 - arg1 type1 min1 max1) - (define (fail) (with-cps cps #f)) - (define (negate arg) +(define-unary-primcall-reducer (mul/immediate cps k src constant + arg type min max) + (cond + ((not (type<=? type &number)) + (with-cps cps #f)) + ((eqv? constant -1) + ;; (* arg -1) -> (- 0 arg) (with-cps cps ($ (with-cps-constants ((zero 0)) (build-term ($continue k src ($primcall 'sub #f (zero arg)))))))) - (define (zero) + ((and (eqv? constant 0) + (type<=? type (logior &exact-integer &fraction))) + ;; (* arg 0) -> 0 if arg is exact (with-cps cps (build-term ($continue k src ($const 0))))) - (define (identity arg) + ((eqv? constant 1) + ;; (* arg 1) -> arg (with-cps cps (build-term ($continue k src ($values (arg)))))) - (define (double arg) + ((eqv? constant 2) + ;; (* arg 2) -> (+ arg arg) (with-cps cps (build-term ($continue k src ($primcall 'add #f (arg arg)))))) - (define (power-of-two constant arg) + ((and (type<=? type &exact-integer) + (positive? constant) + (zero? (logand constant (1- constant)))) + ;; (* arg power-of-2) -> (lsh arg (log2 power-of-2)) (let ((n (let lp ((bits 0) (constant constant)) (if (= constant 1) bits (lp (1+ bits) (ash constant -1)))))) (with-cps cps (build-term ($continue k src ($primcall 'lsh/immediate n (arg))))))) - (define (mul/constant constant constant-type arg arg-type) - (cond - ((not (or (type<=? constant-type &exact-integer) - (= constant-type arg-type))) - (fail)) - ((eqv? constant -1) - ;; (* arg -1) -> (- 0 arg) - (negate arg)) - ((eqv? constant 0) - ;; (* arg 0) -> 0 if arg is exact - (and (type<=? constant-type &exact-integer) - (type<=? arg-type (logior &exact-integer &fraction)) - (zero))) - ((eqv? constant 1) - ;; (* arg 1) -> arg - (identity arg)) - ((eqv? constant 2) - ;; (* arg 2) -> (+ arg arg) - (double arg)) - ((and (type<=? (logior constant-type arg-type) &exact-integer) - (positive? constant) - (zero? (logand constant (1- constant)))) - ;; (* arg power-of-2) -> (ash arg (log2 power-of-2)) - (power-of-two constant arg)) - (else - (fail)))) - (cond - ((logtest (logior type0 type1) (lognot &number)) (fail)) - ((= min0 max0) (mul/constant min0 type0 arg1 type1)) - ((= min1 max1) (mul/constant min1 type1 arg0 type0)) - (else (fail)))) + (else + (with-cps cps #f)))) -(define-binary-primcall-reducer (logbit? cps k src +(define-binary-primcall-reducer (logbit? cps k src param arg0 type0 min0 max0 arg1 type1 min1 max1) ;; FIXME: Use an unboxed number for the mask instead of a fixnum. @@ -252,8 +294,12 @@ ($continue kmask src ($const (ash 1 min0))))) (with-cps cps ($ (with-cps-constants ((one 1)) + (letv n) + (letk kn ($kargs ('n) (n) + ($continue kmask src + ($primcall 'lsh #f (one n))))) (build-term - ($continue kmask src ($primcall 'ash #f (one arg0))))))))) + ($continue kn src ($primcall 'untag-fixnum #f (arg0))))))))) (with-cps cps (letv mask test) (letk kt ($kargs () () @@ -272,34 +318,33 @@ ($ (compute-mask kmask src)))) ;; Hairiness because we are converting from a primcall with unknown ;; arity to a branching primcall. - (let ((positive-fixnum-bits (- (* (target-word-size) 8) 3))) - (if (and (type<=? type0 &exact-integer) - (<= 0 min0 positive-fixnum-bits) - (<= 0 max0 positive-fixnum-bits)) - (match (intmap-ref cps k) - (($ $kreceive arity kargs) - (match arity - (($ $arity (_) () (not #f) () #f) - (with-cps cps - (letv bool) - (let$ body (with-cps-constants ((nil '())) - (build-term - ($continue kargs src ($values (bool nil)))))) - (letk kbool ($kargs (#f) (bool) ,body)) - ($ (convert-to-logtest kbool)))) - (_ - (with-cps cps - (letv bool) - (letk kbool ($kargs (#f) (bool) - ($continue k src ($primcall 'values #f (bool))))) - ($ (convert-to-logtest kbool)))))) - (($ $ktail) - (with-cps cps - (letv bool) - (letk kbool ($kargs (#f) (bool) - ($continue k src ($values (bool))))) - ($ (convert-to-logtest kbool))))) - (with-cps cps #f)))) + (if (and (type<=? type0 &exact-integer) + (<= 0 min0 (target-most-positive-fixnum)) + (<= 0 max0 (target-most-positive-fixnum))) + (match (intmap-ref cps k) + (($ $kreceive arity kargs) + (match arity + (($ $arity (_) () (not #f) () #f) + (with-cps cps + (letv bool) + (let$ body (with-cps-constants ((nil '())) + (build-term + ($continue kargs src ($values (bool nil)))))) + (letk kbool ($kargs (#f) (bool) ,body)) + ($ (convert-to-logtest kbool)))) + (_ + (with-cps cps + (letv bool) + (letk kbool ($kargs (#f) (bool) + ($continue k src ($primcall 'values #f (bool))))) + ($ (convert-to-logtest kbool)))))) + (($ $ktail) + (with-cps cps + (letv bool) + (letk kbool ($kargs (#f) (bool) + ($continue k src ($values (bool))))) + ($ (convert-to-logtest kbool))))) + (with-cps cps #f))) @@ -343,35 +388,43 @@ (setk label ($kargs names vars ($continue k* src ($primcall name param args)))))))))) + (define (transform-primcall f cps label names vars k src name param args) + (and f + (match args + ((arg0) + (call-with-values (lambda () (lookup-pre-type types label arg0)) + (lambda (type0 min0 max0) + (call-with-values (lambda () + (f cps k src param arg0 type0 min0 max0)) + (lambda (cps term) + (and term + (with-cps cps + (setk label ($kargs names vars ,term))))))))) + ((arg0 arg1) + (call-with-values (lambda () (lookup-pre-type types label arg0)) + (lambda (type0 min0 max0) + (call-with-values (lambda () (lookup-pre-type types label arg1)) + (lambda (type1 min1 max1) + (call-with-values (lambda () + (f cps k src param arg0 type0 min0 max0 + arg1 type1 min1 max1)) + (lambda (cps term) + (and term + (with-cps cps + (setk label ($kargs names vars ,term))))))))))) + (_ #f)))) (define (reduce-primcall cps label names vars k src name param args) - (and=> - (hashq-ref *primcall-reducers* name) - (lambda (reducer) - (match args - ((arg0) - (call-with-values (lambda () (lookup-pre-type types label arg0)) - (lambda (type0 min0 max0) - (call-with-values (lambda () - (reducer cps k src param - arg0 type0 min0 max0)) - (lambda (cps term) - (and term - (with-cps cps - (setk label ($kargs names vars ,term))))))))) - ((arg0 arg1) - (call-with-values (lambda () (lookup-pre-type types label arg0)) - (lambda (type0 min0 max0) - (call-with-values (lambda () (lookup-pre-type types label arg1)) - (lambda (type1 min1 max1) - (call-with-values (lambda () - (reducer cps k src param - arg0 type0 min0 max0 - arg1 type1 min1 max1)) - (lambda (cps term) - (and term - (with-cps cps - (setk label ($kargs names vars ,term))))))))))) - (_ #f))))) + (cond + ((transform-primcall (hashq-ref *primcall-macro-reducers* name) + cps label names vars k src name param args) + => (lambda (cps) + (match (intmap-ref cps label) + (($ $kargs names vars + ($ $continue k src ($ $primcall name param args))) + (reduce-primcall cps label names vars k src name param args))))) + ((transform-primcall (hashq-ref *primcall-reducers* name) + cps label names vars k src name param args)) + (else cps))) (define (fold-unary-branch cps label names vars kf kt src name param arg) (and=> (hashq-ref *branch-folders* name) @@ -412,11 +465,9 @@ (match (intmap-ref cps k) (($ $kargs (_) (def)) (or (fold-primcall cps label names vars k src name param args def) - (reduce-primcall cps label names vars k src name param args) - cps)) + (reduce-primcall cps label names vars k src name param args))) (_ - (or (reduce-primcall cps label names vars k src name param args) - cps)))) + (reduce-primcall cps label names vars k src name param args)))) (($ $branch kt ($ $primcall name param args)) ;; We might be able to fold primcalls that branch. (match args