utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import pickle
  2. import numpy as np
  3. def loadData(fileName, timestep, absolute=False):
  4. """Load the data in data/fileName/ at the given timestep.
  5. Parameters:
  6. fileName (string): data is loaded from data/fileName/
  7. timestep (integer): timestep at which to load the data
  8. absolute (bool): true to load relative to parent folder.
  9. Used for interactive jupyter notebook.
  10. Returns:
  11. data instance that was saved by Simulation.save
  12. """
  13. prepend = './..' if absolute else './'
  14. return pickle.load(open(prepend+'/data/{}/data{}.pickle'.format(fileName, timestep), "rb" ))
  15. def calculateEccentricity(M, ref_r, ref_v, r, v):
  16. """Calculates the orbital eccentricity of a list of particles.
  17. Parameters:
  18. M (string): mass of the central object
  19. ref_r (array): reference position of the central object
  20. ref_v (array): reference velocity of the central object
  21. r (array): list of positions of all particles (n particles, 3)
  22. v (array): list of velocities of all particles (n particles, 3)
  23. Returns:
  24. Array of shape (n particles, 3) with the eccentricities of the orbits
  25. """
  26. v1 = v - ref_v
  27. r1 = r - ref_r
  28. h_vec = np.cross(r1, v1)
  29. h = np.linalg.norm(h_vec, axis=-1)
  30. # Make use of $\vec{e} = (v \times (r \times v)) / M$, and e = |\vec{e}|
  31. e_vec = np.cross(v1, h_vec)/M - r1/np.linalg.norm(r1, axis=-1, keepdims=True)
  32. e = np.linalg.norm(e_vec, axis=-1)
  33. return e
############################# PLOTTING UTILS ########################
# Matplotlib can be tedious at times. These short helper functions simplify
# the repetitive parts and keep the plots consistent.
  37. def plotCOM(ax):
  38. """Plots cross at (0,0)"""
  39. ax.scatter([0],[0],marker='+', c='black', s=500,
  40. alpha=.5, linewidth=1, zorder=-1)
  41. def plotCenterMasses(ax, data):
  42. """Plots stars at the position of the central masses"""
  43. ax.scatter(data['r_vec'][data['type'][:,0]=='center'][:,0],
  44. data['r_vec'][data['type'][:,0]=='center'][:,1],
  45. s=100, marker="*", c='black', alpha=.7)
  46. def plotTracks(ax, tracks):
  47. """Plots the tracks of the central masses"""
  48. for track in tracks:
  49. ax.plot(track[:,0], track[:,1],
  50. c='black', alpha=1.0, linewidth=1)
  51. def setSize(ax, x=None, y=None, mode=None):
  52. """Sets the size of the plot. If mode='square', the x and y axis
  53. will have the same scale."""
  54. if mode=='square': ax.axis('square')
  55. if x is not None: ax.set_xlim(x)
  56. if y is not None: ax.set_ylim(y)
  57. def setAxes(ax, x=None, y=None, xcoords=None, ycoords=None, mode=None):
  58. """"Sets the axis labels (x, y) and their position (None). The mode
  59. keyword can be used to hide all the axes ('hide'), only plot the bottom
  60. axis ('bottom') or only plot the bottom and left axes ('bottomleft')"""
  61. if x is not None: ax.set_xlabel(x)
  62. if y is not None: ax.set_ylabel(y)
  63. if mode=='hide':
  64. ax.spines['right'].set_visible(False)
  65. ax.spines['top'].set_visible(False)
  66. ax.spines['left'].set_visible(False)
  67. ax.spines['bottom'].set_visible(False)
  68. ax.set_xticklabels([])
  69. ax.set_xticks([])
  70. ax.set_yticklabels([])
  71. ax.set_yticks([])
  72. elif mode=='bottomleft':
  73. ax.spines['right'].set_visible(False)
  74. ax.spines['top'].set_visible(False)
  75. elif mode=='bottom':
  76. ax.get_yaxis().set_ticks([])
  77. ax.set_ylabel('')
  78. ax.spines['left'].set_visible(False)
  79. ax.spines['right'].set_visible(False)
  80. ax.spines['top'].set_visible(False)
  81. if xcoords is not None: ax.xaxis.set_label_coords(*xcoords)
  82. if ycoords is not None: ax.yaxis.set_label_coords(*ycoords)
  83. def stylizePlot(axs):
  84. """Adds ticks (a la root style) for prettier plots"""
  85. for ax in axs:
  86. ax.tick_params(axis='both', which='both', direction='in', top=True, right=True)
  87. ax.minorticks_on()