diff --git a/module/language/cps/contification.scm b/module/language/cps/contification.scm index 970432adf..aa162e021 100644 --- a/module/language/cps/contification.scm +++ b/module/language/cps/contification.scm @@ -30,7 +30,7 @@ (define-module (language cps contification) #:use-module (ice-9 match) - #:use-module ((srfi srfi-1) #:select (concatenate)) + #:use-module ((srfi srfi-1) #:select (concatenate filter-map)) #:use-module (srfi srfi-26) #:use-module (language cps) #:use-module (language cps dfg) @@ -49,8 +49,8 @@ (set! call-substs (acons sym (map cons arities body-ks) call-substs))) (define (subst-return! old-tail new-tail) (set! cont-substs (acons old-tail new-tail cont-substs))) - (define (elide-function! k) - (set! fun-elisions (cons k fun-elisions))) + (define (elide-function! k cont) + (set! fun-elisions (acons k cont fun-elisions))) (define (splice-conts! scope conts) (hashq-set! cont-splices scope (append conts (hashq-ref cont-splices scope '())))) @@ -230,7 +230,7 @@ (if (and=> (bound-symbol k) (lambda (sym) (contify-fun term-k sym self tail-k arity body))) - (elide-function! k) + (elide-function! k (lookup-cont k cont-table)) (visit-fun exp))) (_ #t))))) @@ -276,10 +276,10 @@ (($ $letrec names syms funs body) ($letrec names syms funs ,(lp body))) (($ $letk conts* body) - ($letk ,(append conts* (map visit-cont cont)) + ($letk ,(append conts* (filter-map visit-cont cont)) ,body)) (body - ($letk ,(map visit-cont cont) + ($letk ,(filter-map visit-cont cont) ,body))))))) (define (visit-fun term) (rewrite-cps-exp term @@ -287,9 +287,9 @@ ($fun meta free ,(visit-cont body))))) (define (visit-cont cont) (rewrite-cps-cont cont - (($ $cont (and k (? (cut memq <> fun-elisions))) src - ($ $kargs (_) (_) body)) - (k src ($kargs () () ,(visit-term body k)))) + (($ $cont (? (cut assq <> fun-elisions))) + ;; This cont gets inlined in place of the $fun. + ,#f) (($ $cont sym src ($ $kargs names syms body)) (sym src ($kargs names syms ,(visit-term body sym)))) (($ $cont sym src ($ $kentry self tail clauses)) @@ -312,10 +312,10 @@ (($ $letrec names syms funs body) ($letrec names syms funs ,(lp body))) (($ $letk conts* body) - ($letk ,(append conts* (map visit-cont conts)) + ($letk ,(append conts* (filter-map visit-cont conts)) ,body)) (body - ($letk ,(map visit-cont conts) + ($letk ,(filter-map visit-cont conts) ,body))))) (($ $letrec names syms funs body) (rewrite-cps-term (filter (match-lambda @@ -329,10 +329,13 @@ term-k (match exp (($ $fun) - (if (memq k fun-elisions) - (build-cps-term - ($continue k ($values ()))) - (continue k (visit-fun exp)))) + (cond + ((assq-ref fun-elisions k) + => (match-lambda + (($ $kargs (_) (_) body) + (visit-term body k)))) + (else + (continue k (visit-fun exp))))) (($ $call proc args) (or (contify-call proc args) (continue k exp)))