1
0

unittests.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  import unittest
  # Star import supplies the implementations referenced below
  # (bruteForce, bruteForceNumba, bruteForceNumbaOptimized, bruteForceCPP).
  from acceleration import *
  import numpy as np
  from numpy.testing import assert_almost_equal

  # Every implementation exercised by the analytic tests; each one is called
  # as fun(r_vec, m_vec, softening).
  accelerationFunctions = [
      bruteForce,
      bruteForceNumba,
      bruteForceNumbaOptimized,
      bruteForceCPP,
  ]
  class MyTest(unittest.TestCase):
      """Compare every acceleration implementation with known results.

      Each function under test is called as ``fun(r_vec, m_vec, softening)``
      with one position row per particle in ``r_vec`` and one mass per
      particle in ``m_vec``. The returned array is compared with an expected
      per-particle acceleration array. The expected values correspond to
      G = 1: a unit mass at unit distance produces unit acceleration
      (see ``test1``).
      """

      def _assert_all_match(self, functions, r_vec, m_vec, softening,
                            expected):
          """Assert that every function in ``functions`` reproduces ``expected``.

          Each implementation runs in its own ``subTest``. A failure therefore
          names the offending function and does not stop the remaining
          implementations from being checked.
          """
          for fun in functions:
              label = getattr(fun, "__name__", repr(fun))
              with self.subTest(implementation=label):
                  assert_almost_equal(fun(r_vec, m_vec, softening), expected)

      def test1(self):
          """Massive particles: two unit masses one unit apart attract."""
          r_vec = np.array([[0., 0, 0], [1, 0, 0]])
          m_vec = np.array([1., 1])
          result = np.array([[1., 0, 0], [-1, 0, 0]])
          self._assert_all_match(accelerationFunctions, r_vec, m_vec, 0,
                                 result)

      def test2(self):
          """Massless particles: a massless particle exerts no pull."""
          r_vec = np.array([[0., 0, 0], [1, 0, 0]])
          m_vec = np.array([1., 0])
          result = np.array([[0, 0, 0], [-1, 0, 0]])
          self._assert_all_match(accelerationFunctions, r_vec, m_vec, 0,
                                 result)

      def test3(self):
          """Softening 0.1 reduces the unit-distance pull to 1/(1 + 0.1)**2."""
          softening = 0.1
          r_vec = np.array([[0., 0, 0], [1, 0, 0]])
          m_vec = np.array([1., 1])
          magnitude = 1 / (1 + softening) ** 2
          result = np.array([[magnitude, 0, 0], [-magnitude, 0, 0]])
          self._assert_all_match(accelerationFunctions, r_vec, m_vec,
                                 softening, result)

      def test4(self):
          """3 dimensions: pull along the (1, 1, 1) diagonal, distance sqrt(3)."""
          r_vec = np.array([[0., 0, 0], [1, 1, 1]])
          m_vec = np.array([1., 0])
          result = np.array([[0, 0, 0], [-1, -1, -1]]) / np.sqrt(3) ** 3
          self._assert_all_match(accelerationFunctions, r_vec, m_vec, 0.,
                                 result)

      def test5(self):
          """More particles: three massless probes around one unit mass."""
          r_vec = np.array([[1., 0, 0], [0, 0, 0], [2, 0, 0], [3, 0, 0]])
          m_vec = np.array([1., 0, 0, 0])
          result = np.array([[0, 0, 0], [1, 0, 0], [-1, 0, 0], [-1/4, 0, 0]])
          self._assert_all_match(accelerationFunctions, r_vec, m_vec, 0.,
                                 result)

      # Must not reuse the name ``test5``: a second definition with that name
      # would silently replace the "more particles" test above.
      def test6(self):
          """Many particles: optimised versions agree with ``bruteForce``."""
          softening = 0.1
          # Fixed seed so any disagreement can be reproduced exactly.
          rng = np.random.default_rng(12345)
          r_vec = rng.random((100, 3))
          m_vec = rng.random(100)
          result = bruteForce(r_vec, m_vec, softening)
          self._assert_all_match(
              [bruteForceNumba, bruteForceNumbaOptimized, bruteForceCPP],
              r_vec, m_vec, softening, result,
          )
  if __name__ == "__main__":
      # When executed as a script, hand control to unittest's command-line
      # runner for the TestCase classes in this module.
      unittest.main()