segmentation.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. """Segmentation algorithm used to identify the different structures
  2. that are formed in the encounter. This file can be called from the
  3. command line to make an illustrative plot of the algorithm.
  4. """
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import matplotlib.patches as patches
  8. import utils
  9. def segmentEncounter(data, plot=False, mode='all'):
  10. """Segment the encounter into tail, bridge, orbitting and
  11. stolen particles as described in the report.
  12. Parameters:
  13. data: A data instance as saved by the simulation to a pickle file
  14. plot: If true the segmentation will be plotted and shown. Highly
  15. useful for debugging.
  16. mode (string): If mode is 'all' all parts of the encounter will be
  17. identified. If mode is 'bridge' only the bridge will be
  18. identified. This is useful when there may be no tail.
  19. Returns:
  20. masks (tupple): tupple of array corresponding to the masks of the
  21. (bridge, stolen, orbitting, tail) particles. One can then use
  22. e.g. data['r_vec'][bridgeMask].
  23. shape (tupple): tupple of (distances, angles) as measured from the
  24. center of mass and with respect to the x axis. They define the
  25. shape of the tail
  26. length (float): total length of the tail.
  27. """
  28. nRings = 100 # number of rings to use when segmenting the data
  29. # Localize the central masses
  30. r_vec = data['r_vec']
  31. centers = r_vec[data['type'][:,0]=='center']
  32. rCenters_vec = centers[1] - centers[0]
  33. rCenters = np.linalg.norm(rCenters_vec)
  34. rCenters_unit = rCenters_vec/np.linalg.norm(rCenters_vec)
  35. # Take particles to be on the tail a priori and
  36. # remove them as they are found in other structures
  37. particlesLeft = np.arange(0, len(r_vec))
  38. if plot:
  39. colour = '#c40f4c'
  40. f, axs = plt.subplots(1, 3, figsize=(9, 4), sharey=False)
  41. f.subplots_adjust(hspace=0, wspace=0)
  42. axs[0].scatter(r_vec[:,0], r_vec[:,1], c=colour, alpha=0.1, s=0.1)
  43. axs[0].axis('equal')
  44. utils.plotCenterMasses(axs[0], data)
  45. axs[0].axis('off')
  46. # Step 1: project points to see if they are part of the bridge
  47. parallelProjection = np.dot(r_vec - centers[0], rCenters_unit)
  48. perpendicularProjection = np.linalg.norm(r_vec - centers[0][np.newaxis]
  49. - parallelProjection[:,np.newaxis] * rCenters_unit[np.newaxis], axis=-1)
  50. bridgeMask = np.logical_and(np.logical_and(0.3*rCenters < parallelProjection,
  51. parallelProjection < .7*rCenters), perpendicularProjection < 2)
  52. # Remove the bridge
  53. notInBridge = np.logical_not(bridgeMask)
  54. r_vec = r_vec[notInBridge]
  55. particlesLeft = particlesLeft[notInBridge]
  56. if mode == 'bridge':
  57. return (bridgeMask, None, None, None), None, None
  58. # Step 2: select stolen particles by checking distance to centers
  59. stolenMask = (np.linalg.norm(r_vec - centers[0][np.newaxis], axis=-1)
  60. > np.linalg.norm(r_vec - centers[1][np.newaxis], axis=-1))
  61. # Remove the stolen part
  62. notStolen = np.logical_not(stolenMask)
  63. r_vec = r_vec[notStolen]
  64. particlesLeft, stolenMask = particlesLeft[notStolen], particlesLeft[stolenMask]
  65. # Step 3: segment data into concentric rings (spherical shells really)
  66. r_vec = r_vec - centers[0]
  67. r = np.linalg.norm(r_vec, axis=-1)
  68. edges = np.linspace(0, 30, nRings) # nRings concentric spheres
  69. indices = np.digitize(r, edges) # Classify particles into shells
  70. if plot:
  71. axs[1].scatter(r_vec[:,0], r_vec[:,1], c=colour, alpha=.1, s=.1)
  72. axs[1].axis('equal')
  73. axs[1].scatter(0, 0, s=100, marker="*", c='black', alpha=.7)
  74. axs[1].axis('off')
  75. # Step 4: find start of tail
  76. start = None
  77. for i in range(1, nRings+1):
  78. rMean = np.mean(r[indices==i])
  79. rMean_vec = np.mean(r_vec[indices==i], axis=0)
  80. parameter = np.linalg.norm(rMean_vec)/rMean
  81. if plot:
  82. circ = patches.Circle((0,0), edges[i-1], linewidth=0.5,edgecolor='black',facecolor='none', alpha=.7)
  83. axs[1].add_patch(circ)
  84. txtxy = edges[i-1] * np.array([np.sin(i/13), np.cos(i/13)])
  85. axs[1].annotate("{:.2f}".format(parameter), xy=txtxy, backgroundcolor='#ffffff55')
  86. if start is None and parameter>.8 :
  87. start = i #Here starts the tail
  88. startDirection = rMean_vec/np.linalg.norm(rMean_vec)
  89. if not plot: break;
  90. if start is None: #abort if nothing found
  91. raise Exception('Could not identify tail')
  92. # Step 5: remove all circles before start
  93. inInnerRings = indices < start
  94. # Remove particles on the opposite direction to startDirection.
  95. # in the now innermost 5 rings. Likely traces of the bridge.
  96. oppositeDirection = np.dot(r_vec, startDirection) < 0
  97. in5InnermostRings = indices <= start+5
  98. orbitting = np.logical_or(inInnerRings,
  99. np.logical_and(oppositeDirection, in5InnermostRings))
  100. orbittingMask = particlesLeft[orbitting]
  101. r_vec = r_vec[np.logical_not(orbitting)]
  102. tailMask = particlesLeft[np.logical_not(orbitting)]
  103. if plot:
  104. axs[2].scatter(r_vec[:,0], r_vec[:,1], c=colour, alpha=0.1, s=0.1)
  105. axs[2].axis('equal')
  106. axs[2].scatter(0, 0, s=100, marker="*", c='black', alpha=.7)
  107. axs[2].axis('off')
  108. # Step 6: measure tail length and shape
  109. r = np.linalg.norm(r_vec, axis=-1)
  110. indices = np.digitize(r, edges)
  111. # Make list of barycenters
  112. points = [list(np.mean(r_vec[indices==i], axis=0))
  113. for i in range(1, nRings) if len(r_vec[indices==i])!=0]
  114. points = np.array(points)
  115. # Calculate total length
  116. lengths = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))
  117. length = np.sum(lengths)
  118. # Shape (for 2D only)
  119. angles = np.arctan2(points[:,1], points[:,0])
  120. distances = np.linalg.norm(points, axis=-1)
  121. shape = (distances, angles)
  122. if plot:
  123. axs[2].plot(points[:,0], points[:,1], c='black', linewidth=0.5, marker='+')
  124. if plot:
  125. plt.show()
  126. return (bridgeMask, stolenMask, orbittingMask, tailMask), shape, length
  127. if __name__ == "__main__":
  128. data = utils.loadData('200mass', 10400)
  129. segmentEncounter(data, plot=True)