Clojure与Numpy的matrix乘法
我正在使用Clojure中的一个应用程序,需要将大型matrix乘法,并且遇到一些性能问题,与相同的Numpy版本相比。 Numpy似乎能够在一秒之内通过转置乘以一个1,000,000×23的matrix,而相同的clojure代码则需要六分钟。 (我可以打印出Numpy生成的matrix,所以它肯定是评估一切)。
我在这个Clojure代码中做了一些非常错误的事情吗? 我可以尝试模仿Numpy的一些技巧吗?
这是python:
import numpy as np def test_my_mult(n): A = np.random.rand(n*23).reshape(n,23) At = AT t0 = time.time() res = np.dot(AT, A) print time.time() - t0 print np.shape(res) return res # Example (returns a 23x23 matrix): # >>> results = test_my_mult(1000000) # # 0.906938076019 # (23, 23)
和clojure:
(defn feature-vec [n] (map (partial cons 1) (for [x (range n)] (take 22 (repeatedly rand))))) (defn dot-product [xy] (reduce + (map * xy))) (defn transpose "returns the transposition of a `coll` of vectors" [coll] (apply map vector coll)) (defn matrix-mult [mat1 mat2] (let [row-mult (fn [mat row] (map (partial dot-product row) (transpose mat)))] (map (partial row-mult mat2) mat1))) (defn test-my-mult [n afn] (let [xs (feature-vec n) xst (transpose xs)] (time (dorun (afn xst xs))))) ;; Example (yields a 23x23 matrix): ;; (test-my-mult 1000 i/mmult) => "Elapsed time: 32.626 msecs" ;; (test-my-mult 10000 i/mmult) => "Elapsed time: 628.841 msecs" ;; (test-my-mult 1000 matrix-mult) => "Elapsed time: 14.748 msecs" ;; (test-my-mult 10000 matrix-mult) => "Elapsed time: 434.128 msecs" ;; (test-my-mult 1000000 matrix-mult) => "Elapsed time: 375751.999 msecs" ;; Test from wikipedia ;; (def A [[14 9 3] [2 11 15] [0 12 17] [5 2 3]]) ;; (def B [[12 25] [9 10] [8 5]]) ;; user> (matrix-mult AB) ;; ((273 455) (243 235) (244 205) (102 160))
更新:我使用JBLAS库实现了相同的基准,并发现大规模的速度提升。 感谢大家的意见! 有时间把这个吸盘包装在Clojure中。 这是新的代码:
(import '[org.jblas FloatMatrix]) (defn feature-vec [n] (FloatMatrix. (into-array (for [x (range n)] (float-array (cons 1 (take 22 (repeatedly rand)))))))) (defn test-mult [n] (let [xs (feature-vec n) xst (.transpose xs)] (time (let [result (.mmul xst xs)] [(.rows result) (.columns result)])))) ;; user> (test-mult 10000) ;; "Elapsed time: 6.99 msecs" ;; [23 23] ;; user> (test-mult 100000) ;; "Elapsed time: 43.88 msecs" ;; [23 23] ;; user> (test-mult 1000000) ;; "Elapsed time: 383.439 msecs" ;; [23 23] (defn matrix-stream [rows cols] (repeatedly #(FloatMatrix/randn rows cols))) (defn square-benchmark "Times the multiplication of a square matrix." [n] (let [[abc] (matrix-stream nn)] (time (.mmuli abc)) nil)) ;; forma.matrix.jblas> (square-benchmark 10) ;; "Elapsed time: 0.113 msecs" ;; nil ;; forma.matrix.jblas> (square-benchmark 100) ;; "Elapsed time: 0.548 msecs" ;; nil ;; forma.matrix.jblas> (square-benchmark 1000) ;; "Elapsed time: 107.555 msecs" ;; nil ;; forma.matrix.jblas> (square-benchmark 2000) ;; "Elapsed time: 793.022 msecs" ;; nil
Python版本正在编译为C语言中的循环,而Clojure版本正在为每个调用映射的代码构build一个新的中间序列。 您看到的性能差异很可能来自数据结构的差异。
为了获得比这个更好的效果,你可以使用像Incanter这样的图书馆,或者按照这个SO问题的解释来编写你自己的版本。 也见这一个 。 如果你真的想留在序列,以保持懒惰的评估属性等,那么你可以通过查看瞬态内部matrix计算得到一个真正的提高
编辑:忘了添加调整clojure的第一步,打开“警告反思”
Numpy链接到BLAS / Lapack例程,这些例程在机器体系结构的层面上已经被优化了数十年,而Clojure则是以最直接和最幼稚的方式实现乘法。
任何时候你有非平凡的matrix/vector操作来执行,你应该链接到BLAS / LAPACK。
唯一不会更快的是来自语言的小matrix,在语言运行时间和LAPACK之间翻译数据表示的开销超过了计算的时间。
我刚刚在Incanter 1.3和jBLAS 1.2.1之间进行了一场小型的比赛。 代码如下:
(ns ml-class.experiments.mmult [:use [incanter core]] [:import [org.jblas DoubleMatrix]]) (defn -main [m] (let [n 23 m (Integer/parseInt m) ai (matrix (vec (double-array (* mn) (repeatedly rand))) n) ab (DoubleMatrix/rand mn) ti (copy (trans ai)) tb (.transpose ab)] (dotimes [i 20] (print "Incanter: ") (time (mmult ti ai)) (print " jBLAS: ") (time (.mmul tb ab)))))
在我的testing中,Incanter在纯matrix乘法中比jBLAS慢45%左右 。 然而,Incanter trans
函数不会创build一个matrix的新副本,因此(.mmul (.transpose ab) ab)
在jBLAS中的内存是两倍,而且只比(mmult (trans ai) ai)
快15%咒术。
鉴于馅饼丰富的function集(尤其是绘图库),我不认为我会很快切换到jBLAS。 尽pipe如此,我还是希望看到jBLAS和Parallel Colt之间的又一次枪战,也许值得考虑用Incantter中的jBLASreplaceParallel Colt? 🙂
编辑:这里是绝对数字(以毫秒为单位)我在我的(相当慢)PC上:
Incanter: 665.362452 jBLAS: 459.311598 numpy: 353.777885
对于每个图书馆,我已经select了20次的最佳时间,matrix大小为23×400000。
PS。 哈斯克尔hmatrix结果是接近numpy,但我不知道如何正确的基准。
Numpy代码使用内置库,在过去的几十年里,它是由Fortran编写的,并由作者,您的CPU供应商和您的OS分销商(以及Numpy人员)进行了优化,以获得最佳性能。 你只是做了matrix乘法的完全直接的,明显的方法。 真的,performance不一样,这并不奇怪。
但是如果你不愿意在Clojure中做这件事,可以考虑查找更好的algorithm ,使用直接循环而不是更高阶的函数(比如reduce
,或者为Javafind合适的matrix代数库(我怀疑Clojure中有好的代数库,但我真的不知道)由一位称职的math家写的。
最后,看看如何正确写出快速的Clojure。 使用types提示,在你的代码上运行一个分析器(惊喜!你的产品function使用最多的时间),并将高级function放在紧密循环中。
正如@littleidea和其他人所指出的那样,你的numpy版本使用的是LAPACK / BLAS / ATLAS,这比你在clojure中做的任何事情都要快得多,因为它已经经过了多年的精心调整。 🙂
这就是说,Clojure代码最大的问题是它使用了双打,就像盒装双打一样。 我把这称为“懒惰的双重”问题,我已经在工作中遇到了很多次。 截至目前,即使1.3,clojure的collections也不是原始的友好。 (你可以创build一个基元的vector,但它不会帮助你,因为所有的seq。函数都会把它们装箱!我还应该说1.3中的原始改进相当不错,最终帮助我们在集合中不是100%有WRT原始支持。)
在clojure中做任何types的matrix运算时,都需要使用java数组或更好的matrix库。 Incanter使用parrelcolt,但是你需要小心使用什么样的incanter函数…因为它们中的很多使得这些matrix被seqable结束了,这使得双打会给你类似的性能。 (顺便说一句,如果你认为他们会有所帮助的话,我可以发布自己的parrelcolt包装。)
为了使用BLAS库,在java-land中有几个选项。 有了所有这些选项,您必须支付JNA税…所有的数据都必须先复制,然后才能进行处理。 如果您正在进行CPU绑定操作(如matrix分解),并且其处理时间比复制数据需要的时间更长,则此税务是有意义的。 对于使用小matrix的简单操作,留在java-land中的速度可能会更快。 你只需要像上面所做的那样做一些testing,看看最适合你的是什么。
这里是你从Java使用BLAS的选项:
http://code.google.com/p/netlib-java/
- 以上API: http : //code.google.com/p/matrix-toolkits-java/
我应该指出,parrelcolt使用netlib-java项目。 这意味着,我相信,如果你正确地设置它将使用BLAS。 但是,我没有证实这一点。 有关jblas和netlib-java之间差异的解释,请参阅以下主题:我在jblas的邮件列表上开始讨论它:
http://groups.google.com/group/jblas-users/browse_thread/thread/c9b3867572331aa5
我还应该指出通用Javamatrix包库:
http://sourceforge.net/projects/ujmp/
它包装了我提到的所有库,然后一些! 我没有看太多的API,但知道他们的抽象是多么的漏洞。 这似乎是一个不错的项目。 我已经结束了使用我自己的parrelcolt clojure包装,因为他们足够快,我真的很喜欢小马的API。 (Colt使用函数对象,这意味着我只需要传递clojure函数就可以了。)
如果你想在Clojure中使用数字,我强烈推荐使用Incanter,而不是试图推出你自己的matrix函数等等。
Incanter在引擎盖下使用了Parallel Colt ,这非常快。
编辑:
截至2013年初,如果您想在Clojure中进行数字编辑,我强烈build议您查看core.matrix
Numpy针对线性代数进行了高度优化。 当然对于大型matrix来说,大部分的处理都在本地的C代码中。
为了匹配这个性能(假设它可能在Java中),你将不得不去除大部分Clojure的抽象:在迭代大型matrix时不要使用具有匿名函数的map,添加types提示以启用原始Java数组的使用等。
可能最好的select就是使用一个现成的Java数据库优化库(http://math.nist.gov/javanumerics/或类似的)。;
我没有任何具体的答案, 只是一些build议。
- 使用一个分析器来确定在哪里花费时间
- 设置警告reflection和使用types提示在需要的地方
- 你可能不得不放弃一些高层次的构造,并且用循环再次展现出最后一盎司的性能
IME,Clojure代码应该与Java非常接近(2或3X)。 但是你必须努力。
只有使用map()才有意义。 这意味着:如果你有一个特定的问题,如乘以两个matrix,不要试图映射()它,只是乘以matrix。
我倾向于只在语言意义上使用map()(即如果程序比没有程序更可读)。 乘法matrix显然是一个循环,映射它是没有意义的。
你的。
佩德罗福图尼。