From cc0c7336151923dea880b27ea6f4904cb8cc44bf Mon Sep 17 00:00:00 2001 From: Matthew Butterick Date: Sat, 14 Feb 2015 13:01:33 -0800 Subject: [PATCH] working --- quad/ocm-typed-test.rkt | 94 ++++++++++++------------ quad/ocm-typed.rkt | 158 ++++++++++++---------------------------- 2 files changed, 93 insertions(+), 159 deletions(-) diff --git a/quad/ocm-typed-test.rkt b/quad/ocm-typed-test.rkt index e2a2ecc0..4495156e 100644 --- a/quad/ocm-typed-test.rkt +++ b/quad/ocm-typed-test.rkt @@ -4,23 +4,23 @@ (require math) (define m0 (matrix ((25.0 42.0 57.0 78.0 90.0 103.0 123.0 142.0 151.0) - (21.0 35.0 48.0 65.0 76.0 85.0 105.0 123.0 130.0) - (13.0 26.0 35.0 51.0 58.0 67.0 86.0 100.0 104.0) - (10.0 20.0 28.0 42.0 48.0 56.0 75.0 86.0 88.0) - (20.0 29.0 33.0 44.0 49.0 55.0 73.0 82.0 80.0) - (13.0 21.0 24.0 35.0 39.0 44.0 59.0 65.0 59.0) - (19.0 25.0 28.0 38.0 42.0 44.0 57.0 61.0 52.0) - (35.0 37.0 40.0 48.0 48.0 49.0 62.0 62.0 49.0) - (37.0 36.0 37.0 42.0 39.0 39.0 51.0 50.0 37.0) - (41.0 39.0 37.0 42.0 35.0 33.0 44.0 43.0 29.0) - (58.0 56.0 54.0 55.0 47.0 41.0 50.0 47.0 29.0) - (66.0 64.0 61.0 61.0 51.0 44.0 52.0 45.0 24.0) - (82.0 76.0 72.0 70.0 56.0 49.0 55.0 46.0 23.0) - (99.0 91.0 83.0 80.0 63.0 56.0 59.0 46.0 20.0) - (124.0 116.0 107.0 100.0 80.0 71.0 72.0 58.0 28.0) - (133.0 125.0 113.0 106.0 86.0 75.0 74.0 59.0 25.0) - (156.0 146.0 131.0 120.0 97.0 84.0 80.0 65.0 31.0) - (178.0 164.0 146.0 135.0 110.0 96.0 92.0 73.0 39.0)))) + (21.0 35.0 48.0 65.0 76.0 85.0 105.0 123.0 130.0) + (13.0 26.0 35.0 51.0 58.0 67.0 86.0 100.0 104.0) + (10.0 20.0 28.0 42.0 48.0 56.0 75.0 86.0 88.0) + (20.0 29.0 33.0 44.0 49.0 55.0 73.0 82.0 80.0) + (13.0 21.0 24.0 35.0 39.0 44.0 59.0 65.0 59.0) + (19.0 25.0 28.0 38.0 42.0 44.0 57.0 61.0 52.0) + (35.0 37.0 40.0 48.0 48.0 49.0 62.0 62.0 49.0) + (37.0 36.0 37.0 42.0 39.0 39.0 51.0 50.0 37.0) + (41.0 39.0 37.0 42.0 35.0 33.0 44.0 43.0 29.0) + (58.0 56.0 54.0 55.0 47.0 41.0 50.0 47.0 29.0) + (66.0 64.0 61.0 61.0 51.0 44.0 52.0 45.0 24.0) + (82.0 76.0 72.0 70.0 56.0 49.0 55.0 46.0 23.0) + (99.0 91.0 83.0 80.0 63.0 56.0 59.0 46.0 20.0) + (124.0 116.0 107.0 100.0 80.0 71.0 72.0 58.0 28.0) + (133.0 125.0 113.0 106.0 86.0 75.0 74.0 59.0 25.0) + (156.0 146.0 131.0 120.0 97.0 84.0 80.0 65.0 31.0) + (178.0 164.0 146.0 135.0 110.0 96.0 92.0 73.0 39.0)))) (define m (matrix->list* m0)) (define m2 (matrix->list* (matrix-transpose m0))) @@ -35,15 +35,15 @@ ;; proc must return a value even for out-of-bounds i and j (: simple-proc Matrix-Proc-Type) (define (simple-proc i j) (cast (fl (with-handlers [(exn:fail? (λ(exn) (* -1 i)))] - ((inst list-ref Value-Type) ((inst list-ref (Listof Value-Type)) m i) j))) Value-Type)) + ((inst list-ref Value-Type) ((inst list-ref (Listof Value-Type)) m i) j))) Value-Type)) (: simple-proc2 Matrix-Proc-Type) (define (simple-proc2 i j) (cast (fl (with-handlers [(exn:fail? (λ(exn) (* -1 i)))] - ((inst list-ref Value-Type) ((inst list-ref (Listof Value-Type)) m2 i) j))) Value-Type)) + ((inst list-ref Value-Type) ((inst list-ref (Listof Value-Type)) m2 i) j))) Value-Type)) (check-equal? (simple-proc 0 2) 57.0) ; 0th row, 2nd col (check-equal? (simple-proc2 2 0) 57.0) ; flipped -(define row-indices (cast (list->vector (range (length m))) (Vectorof Nonnegative-Integer))) -(define col-indices (cast (list->vector (range (length (car m)))) (Vectorof Nonnegative-Integer))) +(define row-indices (cast (list->vector (range (length m))) (Vectorof Index-Type))) +(define col-indices (cast (list->vector (range (length (car m)))) (Vectorof Index-Type))) (define result (concave-minima row-indices col-indices simple-proc simple-entry->value)) @@ -53,31 +53,33 @@ (list (hash-ref h 'value) (hash-ref h 'row-idx))) '((10.0 3) (20.0 3) (24.0 5) (35.0 5) (35.0 9) (33.0 9) (44.0 9) (43.0 9) (20.0 13))) ; checked against SMAWK.py -#| - (define o (make-ocm simple-proc simple-entry->value)) - (check-equal? - (for/list : (Listof (Pairof Integer Any)) ([j (in-vector col-indices)]) - (define h (hash-ref result j)) - (list (hash-ref h 'value) (hash-ref h 'row-idx))) - '((10 3) (20 3) (24 5) (35 5) (35 9) (33 9) (44 9) (43 9) (20 13))) ; checked against SMAWK.py - (check-equal? - (for/list : (Listof (Pairof Integer Any)) ([j (in-vector col-indices)]) - (list (min-value o j) (min-index o j))) - '((0 none) (42 0) (48 1) (51 2) (48 3) (55 4) (59 5) (61 6) (49 7))) ; checked against SMAWK.py - - (define o2 (make-ocm simple-proc2)) - (define row-indices2 (list->vector (range (length m2)))) - (define col-indices2 (list->vector (range (length (car m2))))) - (define result2 (concave-minima row-indices2 col-indices2 simple-proc2 identity)) - (check-equal? - (for/list : (Listof (Pairof Integer Any)) ([j (in-vector col-indices2)]) - (define h (hash-ref result2 j)) - (list (hash-ref h 'value) (hash-ref h 'row-idx))) - '((25 0) (21 0) (13 0) (10 0) (20 0) (13 0) (19 0) (35 0) (36 1) (29 8) (29 8) (24 8) (23 8) (20 8) (28 8) (25 8) (31 8) (39 8))) ; checked against SMAWK.py +(check-equal? + (for/list : (Listof (List Value-Type Index-Type)) ([j (in-vector col-indices)]) + (define h (cast (hash-ref result j) (HashTable Symbol Any))) + (list (cast (hash-ref h 'value) Value-Type) (cast (hash-ref h 'row-idx) Index-Type))) + '((10.0 3) (20.0 3) (24.0 5) (35.0 5) (35.0 9) (33.0 9) (44.0 9) (43.0 9) (20.0 13))) ; checked against SMAWK.py + + + +(define o (make-ocm simple-proc simple-entry->value)) + (check-equal? - (for/list : (Listof (Pairof Integer Any)) ([j (in-vector col-indices2)]) - (list (min-value o2 j) (min-index o2 j))) - '((0 none) (21 0) (13 0) (10 0) (20 0) (13 0) (19 0) (35 0) (36 1) (29 8) (-9 9) (-10 10) (-11 11) (-12 12) (-13 13) (-14 14) (-15 15) (-16 16))) ; checked against SMAWK.py -|# + (for/list : (Listof (List Value-Type (U Index-Type No-Value-Type))) ([j (in-vector col-indices)]) + (list (cast (ocm-min-value o j) Value-Type) (ocm-min-index o j))) + '((0.0 none) (42.0 0) (48.0 1) (51.0 2) (48.0 3) (55.0 4) (59.0 5) (61.0 6) (49.0 7))) ; checked against SMAWK.py + (define row-indices2 (cast (list->vector (range (length m2))) (Vectorof Index-Type))) + (define col-indices2 (cast (list->vector (range (length (car m2)))) (Vectorof Index-Type))) + (define result2 (concave-minima row-indices2 col-indices2 simple-proc2 simple-entry->value)) + (check-equal? + (for/list : (Listof (List Value-Type Index-Type)) ([j (in-vector col-indices2)]) + (define h (cast (hash-ref result2 j) (HashTable Symbol (U Index-Type Value-Type)))) + (list (cast (hash-ref h 'value) Value-Type) (cast (hash-ref h 'row-idx) Index-Type))) + '((25.0 0) (21.0 0) (13.0 0) (10.0 0) (20.0 0) (13.0 0) (19.0 0) (35.0 0) (36.0 1) (29.0 8) (29.0 8) (24.0 8) (23.0 8) (20.0 8) (28.0 8) (25.0 8) (31.0 8) (39.0 8))) ; checked against SMAWK.py + +(define o2 (make-ocm simple-proc2 simple-entry->value)) + (check-equal? + (for/list : (Listof (List Value-Type (U Index-Type No-Value-Type))) ([j (in-vector col-indices2)]) + (list (cast (ocm-min-value o2 j) Value-Type) (ocm-min-index o2 j))) + '((0.0 none) (21.0 0) (13.0 0) (10.0 0) (20.0 0) (13.0 0) (19.0 0) (35.0 0) (36.0 1) (29.0 8) (-9.0 9) (-10.0 10) (-11.0 11) (-12.0 12) (-13.0 13) (-14.0 14) (-15.0 15) (-16.0 16))) ; checked against SMAWK.py \ No newline at end of file diff --git a/quad/ocm-typed.rkt b/quad/ocm-typed.rkt index 9995edee..76efbe1c 100644 --- a/quad/ocm-typed.rkt +++ b/quad/ocm-typed.rkt @@ -3,7 +3,7 @@ (require racket/list sugar/debug racket/function racket/vector "logger-typed.rkt") (define-logger ocm) -(provide smawky? Entry->Value-Type Value-Type Index-Type Matrix-Proc-Type make-ocm reduce reduce2 concave-minima (prefix-out ocm- (combine-out min-value min-index))) +(provide smawky? Entry->Value-Type Value-Type No-Value-Type Index-Type Matrix-Proc-Type make-ocm reduce reduce2 concave-minima (prefix-out ocm- (combine-out min-value min-index))) (: select-elements ((Listof Any) (Listof Index-Type) . -> . (Listof Any))) (define (select-elements xs is) @@ -33,6 +33,13 @@ (define-syntax-rule (vector-append-item xs value) ((inst vector-append Any) xs (vector value))) +(define-syntax-rule (vector-append-value xs value) + ((inst vector-append Value-Type) xs (vector value))) + +(define-syntax-rule (vector-append-index xs value) + ((inst vector-append (U Index-Type No-Value-Type)) xs (vector value))) + + (: vector-set (All (a) ((Vectorof a) Integer a -> (Vectorof a)))) (define (vector-set vec idx val) (vector-set! vec idx val) @@ -193,131 +200,56 @@ (define-type Finished-Value-Type Index-Type) (define-type Matrix-Proc-Type (Index-Type Index-Type . -> . Value-Type)) (define-type Entry->Value-Type (Any . -> . Value-Type)) -;(define-type OCM-Type (Vector (Vectorof Value-Type) (Vectorof (U Index-Type No-Value-Type)) Finished-Value-Type Matrix-Proc-Type Entry->Value-Type Index-Type Index-Type)) -(define-type OCM-Type (Vector Any Any Any Any Any Any)) -(define o:min-values 0) -(define o:min-row-indices 1) -(define o:finished 2) -(define o:matrix-proc 3) -(define o:entry->value 4) -(define o:base 5) -(define o:tentative 6) - -(: ocm-matrix-proc (OCM-Type . -> . Matrix-Proc-Type)) -(define (ocm-matrix-proc o) - (cast (vector-ref o o:matrix-proc) Matrix-Proc-Type)) - -(: ocm-set-matrix-proc (OCM-Type Matrix-Proc-Type . -> . Void)) -(define (ocm-set-matrix-proc o proc) - (vector-set! o o:matrix-proc (cast proc Matrix-Proc-Type))) - - -(: ocm-entry->value (OCM-Type . -> . Entry->Value-Type)) -(define (ocm-entry->value o) - (cast (vector-ref o o:entry->value) Entry->Value-Type)) - -(: ocm-set-entry->value (OCM-Type Entry->Value-Type . -> . Void)) -(define (ocm-set-entry->value o proc) - (vector-set! o o:entry->value (cast proc Entry->Value-Type))) - -(: ocm-finished (OCM-Type . -> . Finished-Value-Type)) -(define (ocm-finished o) - (cast (vector-ref o o:finished) Finished-Value-Type)) - -(: ocm-set-finished (OCM-Type Finished-Value-Type . -> . Void)) -(define (ocm-set-finished o v) - (vector-set! o o:finished (cast v Finished-Value-Type))) - -(: ocm-tentative (OCM-Type . -> . Index-Type)) -(define (ocm-tentative o) - (cast (vector-ref o o:tentative) Index-Type)) - -(: ocm-set-tentative (OCM-Type Index-Type . -> . Void)) -(define (ocm-set-tentative o v) - (vector-set! o o:tentative (cast v Index-Type))) - -(: ocm-base (OCM-Type . -> . Index-Type)) -(define (ocm-base o) - (cast (vector-ref o o:tentative) Index-Type)) - -(: ocm-set-base (OCM-Type Index-Type . -> . Void)) -(define (ocm-set-base o v) - (vector-set! o o:tentative (cast v Index-Type))) - - -(: ocm-min-values (OCM-Type . -> . (Vectorof Value-Type))) -(define (ocm-min-values o) - (cast (vector-ref o o:min-values) (Vectorof Value-Type))) - -(: ocm-set-min-values (OCM-Type (Vectorof Value-Type) . -> . Void)) -(define (ocm-set-min-values o vs) - (vector-set! o o:min-values (cast vs (Vectorof Value-Type)))) - - -(: ocm-min-row-indices (OCM-Type . -> . (Vectorof (U Index-Type No-Value-Type)))) -(define (ocm-min-row-indices o) - (cast (vector-ref o o:min-row-indices) (Vectorof (U Index-Type No-Value-Type)))) - -(: ocm-set-min-row-indices (OCM-Type (Vectorof (U Index-Type No-Value-Type)) . -> . Void)) -(define (ocm-set-min-row-indices o vs) - (vector-set! o o:min-row-indices (cast vs (Vectorof (U Index-Type No-Value-Type))))) +(struct $ocm ([min-values : (Vectorof Value-Type)] [min-row-indices : (Vectorof (U Index-Type No-Value-Type))] [finished : Finished-Value-Type] [matrix-proc : Matrix-Proc-Type] [entry->value : Entry->Value-Type] [base : Index-Type] [tentative : Index-Type]) #:transparent #:mutable) +(define-type OCM-Type $ocm) (: make-ocm ((Matrix-Proc-Type Entry->Value-Type) (Initial-Value-Type) . ->* . OCM-Type)) (define (make-ocm matrix-proc entry->value [initial-value 0.0]) (log-ocm-debug "making new ocm") - (define ocm (cast (make-vector 7) OCM-Type)) - (ocm-set-min-values ocm (vector initial-value)) - (ocm-set-min-row-indices ocm (vector no-value)) - (ocm-set-finished ocm 0) - (ocm-set-matrix-proc ocm matrix-proc) - (ocm-set-entry->value ocm entry->value) ; for converting matrix values to an integer - (ocm-set-base ocm 0) - (ocm-set-tentative ocm 0) - ocm) + ($ocm (vector initial-value) (vector no-value) 0 matrix-proc entry->value 0 0)) ;; Return min { Matrix(i,j) | i < j }. (: min-value (OCM-Type Index-Type . -> . Any)) (define (min-value ocm j) - (if (< (cast (ocm-finished ocm) Real) j) + (if (< (cast ($ocm-finished ocm) Real) j) (begin (advance! ocm) (min-value ocm j)) - (vector-ref (ocm-min-values ocm) j))) + (vector-ref ($ocm-min-values ocm) j))) ;; Return argmin { Matrix(i,j) | i < j }. (: min-index (OCM-Type Index-Type . -> . (U Index-Type No-Value-Type))) (define (min-index ocm j) - (if (< (cast (ocm-finished ocm) Real) j) + (if (< (cast ($ocm-finished ocm) Real) j) (begin (advance! ocm) (min-index ocm j)) - ((inst vector-ref (U Index-Type No-Value-Type)) (ocm-min-row-indices ocm) j))) + ((inst vector-ref (U Index-Type No-Value-Type)) ($ocm-min-row-indices ocm) j))) ;; Finish another value,index pair. (: advance! (OCM-Type . -> . Void)) (define (advance! ocm) - (define next (add1 (ocm-finished ocm))) - (log-ocm-debug "advance! ocm to next = ~a" (add1 (ocm-finished ocm))) + (define next (add1 ($ocm-finished ocm))) + (log-ocm-debug "advance! ocm to next = ~a" (add1 ($ocm-finished ocm))) (cond ;; First case: we have already advanced past the previous tentative ;; value. We make a new tentative value by applying ConcaveMinima ;; to the largest square submatrix that fits under the base. - [(> next (ocm-tentative ocm)) - (log-ocm-debug "advance: first case because next (~a) > tentative (~a)" next (ocm-tentative ocm)) - (define rows : (Vectorof Index-Type) (list->vector (range (ocm-base ocm) next))) - (ocm-set-tentative ocm (+ (ocm-finished ocm) (vector-length rows))) - (define cols : (Vectorof Index-Type) (list->vector (range next (add1 (ocm-tentative ocm))))) - (define minima (concave-minima rows cols (ocm-matrix-proc ocm) (ocm-entry->value ocm))) - (error 'stop) + [(> next ($ocm-tentative ocm)) + (log-ocm-debug "advance: first case because next (~a) > tentative (~a)" next ($ocm-tentative ocm)) + (define rows : (Vectorof Index-Type) (list->vector (range ($ocm-base ocm) next))) + (set-$ocm-tentative! ocm (+ ($ocm-finished ocm) (vector-length rows))) + (define cols : (Vectorof Index-Type) (list->vector (range next (add1 ($ocm-tentative ocm))))) + (define minima (concave-minima rows cols ($ocm-matrix-proc ocm) ($ocm-entry->value ocm))) (for ([col (in-vector cols)]) (cond - [(>= col (vector-length (ocm-min-values ocm))) - (ocm-set-min-values ocm (vector-append-item (ocm-min-values ocm) (@ (cast (@ minima col) HashTableTop) 'value))) - (ocm-set-min-row-indices ocm (vector-append-item (ocm-min-row-indices ocm) (@ (cast (@ minima col) HashTableTop) 'row-idx)))] - [(< ((ocm-entry->value ocm) (@ (cast (@ minima col) HashTableTop) 'value)) ((ocm-entry->value ocm) (vector-ref (ocm-min-values ocm) col))) - (ocm-set-min-values ocm ((inst vector-set Index-Type) (ocm-min-values ocm) col (cast (@ (cast (@ minima col) HashTableTop) 'value) Index-Type))) - (ocm-set-min-row-indices ocm ((inst vector-set Index-Type) (ocm-min-row-indices ocm) col (cast (@ (cast (@ minima col) HashTableTop) 'row-idx) Index-Type)))])) + [(>= col (vector-length ($ocm-min-values ocm))) + (set-$ocm-min-values! ocm (vector-append-value ($ocm-min-values ocm) (@ (cast (@ minima col) (HashTable Symbol Value-Type)) 'value))) + (set-$ocm-min-row-indices! ocm (vector-append-index ($ocm-min-row-indices ocm) (@ (cast (@ minima col) (HashTable Symbol Index-Type)) 'row-idx)))] + [(< (($ocm-entry->value ocm) (@ (cast (@ minima col) HashTableTop) 'value)) (($ocm-entry->value ocm) (vector-ref ($ocm-min-values ocm) col))) + (set-$ocm-min-values! ocm ((inst vector-set Value-Type) ($ocm-min-values ocm) col (cast (@ (cast (@ minima col) HashTableTop) 'value) Value-Type))) + (set-$ocm-min-row-indices! ocm ((inst vector-set (U Index-Type No-Value-Type)) ($ocm-min-row-indices ocm) col (cast (@ (cast (@ minima col) HashTableTop) 'row-idx) Index-Type)))])) - (ocm-set-finished ocm next)] + (set-$ocm-finished! ocm next)] [else ;; Second case: the new column minimum is on the diagonal. @@ -325,23 +257,23 @@ ;; so we can clear out all our work from higher rows. ;; As in the fourth case, the loss of tentative is ;; amortized against the increase in base. - (define diag ((ocm-matrix-proc ocm) (sub1 next) next)) + (define diag (($ocm-matrix-proc ocm) (sub1 next) next)) (cond - [(< ((ocm-entry->value ocm) diag) ((ocm-entry->value ocm) (vector-ref (ocm-min-values ocm) next))) + [(< (($ocm-entry->value ocm) diag) (($ocm-entry->value ocm) (vector-ref ($ocm-min-values ocm) next))) (log-ocm-debug "advance: second case because column minimum is on the diagonal") - (ocm-set-min-values ocm (vector-set (ocm-min-values ocm) next diag)) - (ocm-set-min-row-indices ocm (vector-set (ocm-min-row-indices ocm) next (sub1 next))) - (ocm-set-base ocm (sub1 next)) - (ocm-set-tentative ocm next) - (ocm-set-finished ocm next)] + (set-$ocm-min-values! ocm (vector-set ($ocm-min-values ocm) next diag)) + (set-$ocm-min-row-indices! ocm (vector-set ($ocm-min-row-indices ocm) next (sub1 next))) + (set-$ocm-base! ocm (sub1 next)) + (set-$ocm-tentative! ocm next) + (set-$ocm-finished! ocm next)] ;; Third case: row i-1 does not supply a column minimum in ;; any column up to tentative. We simply advance finished ;; while maintaining the invariant. - [(>= ((ocm-entry->value ocm) ((ocm-matrix-proc ocm) (sub1 next) (ocm-tentative ocm))) - ((ocm-entry->value ocm) (vector-ref (ocm-min-values ocm) (ocm-tentative ocm)))) + [(>= (($ocm-entry->value ocm) (($ocm-matrix-proc ocm) (sub1 next) ($ocm-tentative ocm))) + (($ocm-entry->value ocm) (vector-ref ($ocm-min-values ocm) ($ocm-tentative ocm)))) (log-ocm-debug "advance: third case because row i-1 does not suppply a column minimum") - (ocm-set-finished ocm next)] + (set-$ocm-finished! ocm next)] ;; Fourth and final case: a new column minimum at self._tentative. ;; This allows us to make progress by incorporating rows @@ -351,14 +283,14 @@ ;; this step) can be amortized against the increase in base. [else (log-ocm-debug "advance: fourth case because new column minimum") - (ocm-set-base ocm (sub1 next)) - (ocm-set-tentative ocm next) - (ocm-set-finished ocm next)])])) + (set-$ocm-base! ocm (sub1 next)) + (set-$ocm-tentative! ocm next) + (set-$ocm-finished! ocm next)])])) (: print (OCM-Type . -> . Void)) (define (print ocm) - (displayln (ocm-min-values ocm)) - (displayln (ocm-min-row-indices ocm))) + (displayln ($ocm-min-values ocm)) + (displayln ($ocm-min-row-indices ocm))) (: smawky? ((Listof (Listof Real)) . -> . Boolean))