Source code for mbtrack2.tracking.element

# -*- coding: utf-8 -*-
"""
This module defines the most basic elements for tracking, including Element,
an abstract base class which is to be used as the parent class of every
element included in the tracking.
"""

from abc import ABCMeta, abstractmethod
from copy import deepcopy
from functools import wraps
from typing import Callable

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorial

from mbtrack2.tracking.particles import Beam, Bunch
from mbtrack2.utilities.synchrotron import Synchrotron


class Element(metaclass=ABCMeta):
    """
    Abstract Element class used for subclass inheritance to define all kinds
    of objects which intervene in the tracking.
    """

    @abstractmethod
    def track(self, bunch: Beam | Bunch):
        """
        Track a Beam or Bunch object through this Element.
        This method needs to be overloaded in each Element subclass.

        Parameters
        ----------
        bunch : Beam or Bunch object
        """
        raise NotImplementedError

    @staticmethod
    def parallel(track: Callable) -> Callable:
        """
        Defines the decorator @parallel which handles the embarrassingly
        parallel case which happens when there is no bunch to bunch
        interaction in the tracking routine.

        Adding @Element.parallel allows to write the track method of the
        Element subclass for a Bunch object instead of a Beam object.

        Parameters
        ----------
        track : function, method of an Element subclass
            track method of an Element subclass which takes a Bunch object as
            input

        Returns
        -------
        track_wrapper: function, method of an Element subclass
            track method of an Element subclass which takes a Beam object or a
            Bunch object as input. The wrapper itself returns None.
        """

        @wraps(track)
        def track_wrapper(*args, **kwargs):
            # args[0] is the element instance, args[1] the Beam or Bunch;
            # anything after that is forwarded untouched to ``track``.
            target = args[1]
            self = args[0]
            extra = args[2:]

            if not isinstance(target, Beam):
                # Plain Bunch (or bunch-like object): call track directly.
                track(self, target, *extra, **kwargs)
            elif target.mpi_switch is True:
                # With MPI, each process only tracks the bunch it owns.
                track(self, target[target.mpi.bunch_num], *extra, **kwargs)
            else:
                for bunch in target.not_empty:
                    track(self, bunch, *extra, **kwargs)

        return track_wrapper

    @staticmethod
    def track_bunch_if_non_empty(track: Callable) -> Callable:
        """
        Defines the decorator @track_bunch_if_non_empty which handles the case
        where a track method should not be called if the bunch is empty.

        Should be added only to track methods defined for Bunch objects.

        Parameters
        ----------
        track : function, method of an Element subclass
            track method of an Element subclass which takes a Bunch object as
            input

        Returns
        -------
        track_wrapper: function, method of an Element subclass
            track method of an Element subclass which takes a Bunch object as
            input. The wrapper itself returns None.
        """

        @wraps(track)
        def track_wrapper(*args, **kwargs):
            # args[0] is the element instance, args[1] the bunch.
            bunch = args[1]
            if not bunch.is_empty:
                # Keyword arguments are forwarded so that this decorator can
                # be stacked with @Element.parallel, which forwards them too.
                track(*args, **kwargs)

        return track_wrapper
class LongitudinalMap(Element):
    """
    One-turn longitudinal map of the synchrotron.

    Parameters
    ----------
    ring : Synchrotron object
    """

    def __init__(self, ring: Synchrotron):
        self.ring = ring

    @Element.parallel
    def track(self, bunch: Bunch | Beam):
        """
        Apply the one-turn longitudinal map to the particles.

        There is no bunch to bunch interaction here, so the method is written
        for a Bunch object and @Element.parallel takes care of Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        ring = self.ring

        # Relative energy lost over the turn (U0 expressed in units of E0).
        bunch["delta"] -= ring.U0 / ring.E0

        # Time drift computed from the already updated energy deviation.
        delta = bunch["delta"]
        bunch["tau"] += ring.eta(delta) * ring.T0 * delta
class SynchrotronRadiation(Element):
    """
    Element to handle synchrotron radiation, radiation damping and quantum
    excitation, for a single turn in the synchrotron.

    Parameters
    ----------
    ring : Synchrotron object
    switch : bool array of shape (3,), optional
        If False in one plane (long, x, y), the synchrotron radiation is
        turned off.
        The default (None) turns it on in all three planes.
    qexcitation : bool, optional
        If False, the quantum excitation is turned off.
        The default is True.
    """

    def __init__(self,
                 ring: Synchrotron,
                 switch: NDArray | None = None,
                 qexcitation: bool = True):
        self.ring = ring
        # A fresh array is built per instance: an array used directly as
        # default argument would be shared (and mutable) across all instances
        # relying on the default.
        if switch is None:
            switch = np.ones((3, ), dtype=bool)
        self.switch = switch
        self.qexcitation = qexcitation

    @Element.parallel
    def track(self, bunch: Bunch | Beam):
        """
        Tracking method for the element.
        No bunch to bunch interaction, so written for Bunch objects and
        @Element.parallel is used to handle Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        ring = self.ring
        n_particles = len(bunch)

        # One row per plane, in the order (long, x, y):
        #   index in self.switch, damped coordinate, index in ring.tau,
        #   callable giving the rms value used to scale the excitation.
        # The rms value is wrapped in a callable so that ring.sigma() is only
        # evaluated when the plane is on and the excitation is requested.
        planes = (
            (0, "delta", 2, lambda: ring.sigma_delta),
            (1, "xp", 0, lambda: ring.sigma()[1]),
            (2, "yp", 1, lambda: ring.sigma()[3]),
        )

        for plane, coord, tau_index, rms in planes:
            if not self.switch[plane]:
                continue

            damping_time = ring.tau[tau_index]
            if self.qexcitation:
                rand = np.random.standard_normal(size=n_particles)
                excitation = 2 * rms() * (ring.T0 / damping_time)**0.5 * rand
            else:
                excitation = 0

            bunch[coord] = (1 - 2 * ring.T0 /
                            damping_time) * bunch[coord] + excitation
class SkewQuadrupole(Element):
    """
    Thin skew quadrupole element used to introduce betatron coupling (the
    length of the quadrupole is neglected).

    Parameters
    ----------
    strength : float
        Integrated strength of the skew quadrupole, i.e. the kick applied to
        xp (resp. yp) per unit of y (resp. x) offset ([m^-1] if positions are
        expressed in metres).
    """

    def __init__(self, strength: float):
        self.strength = strength

    @Element.parallel
    def track(self, bunch: Bunch | Beam):
        """
        Apply the thin skew quadrupole kick to the particles.

        There is no bunch to bunch interaction here, so the method is written
        for a Bunch object and @Element.parallel takes care of Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        k = self.strength

        # Horizontal angle is kicked proportionally to the vertical offset ...
        kick_from_y = k * bunch["y"]
        bunch["xp"] = bunch["xp"] - kick_from_y

        # ... and the vertical angle proportionally to the horizontal offset.
        kick_from_x = k * bunch["x"]
        bunch["yp"] = bunch["yp"] - kick_from_x
class TransverseMapSector(Element):
    """
    Transverse map for a sector of the synchrotron, from an initial
    position s0 to a final position s1.

    Parameters
    ----------
    ring : Synchrotron object
        Ring parameters.
    alpha0 : array of shape (2,)
        Alpha Twiss function at the initial location of the sector.
    beta0 : array of shape (2,)
        Beta Twiss function at the initial location of the sector.
    dispersion0 : array of shape (4,)
        Dispersion function at the initial location of the sector.
    alpha1 : array of shape (2,)
        Alpha Twiss function at the final location of the sector.
    beta1 : array of shape (2,)
        Beta Twiss function at the final location of the sector.
    dispersion1 : array of shape (4,)
        Dispersion function at the final location of the sector.
    phase_diff : array of shape (2,)
        Phase difference between the initial and final location of the
        sector.
    chro_diff : array of shape (2 * order,)
        Chromaticity difference between the initial and final location of
        the sector, stored as (x, y) pairs for each successive order:
        [xi_x1, xi_y1, xi_x2, xi_y2, ...].
    adts : sequence of 4 coefficient arrays, optional
        Amplitude-dependent tune shift of the sector, see Synchrotron class
        for details. Each entry is turned into a numpy.poly1d: entries 0 and
        2 contribute to the horizontal tune advance (as functions of Jx and
        Jy respectively), entries 1 and 3 to the vertical one.
        The default is None.
    """

    def __init__(self,
                 ring: Synchrotron,
                 alpha0: NDArray,
                 beta0: NDArray,
                 dispersion0: NDArray,
                 alpha1: NDArray,
                 beta1: NDArray,
                 dispersion1: NDArray,
                 phase_diff: NDArray,
                 chro_diff: NDArray,
                 adts: NDArray | None = None):
        self.ring = ring
        self.alpha0 = alpha0
        self.beta0 = beta0
        self.gamma0 = (1 + self.alpha0**2) / self.beta0
        self.dispersion0 = dispersion0
        self.alpha1 = alpha1
        self.beta1 = beta1
        self.gamma1 = (1 + self.alpha1**2) / self.beta1
        self.dispersion1 = dispersion1
        self.tune_diff = phase_diff / (2 * np.pi)
        self.chro_diff = chro_diff
        if adts is not None:
            self.adts_poly = [np.poly1d(adts[i]) for i in range(4)]
        else:
            self.adts_poly = None

    def _compute_chromatic_tune_advances(
            self, bunch: Beam | Bunch) -> tuple[NDArray, NDArray]:
        """
        Compute the chromatic contribution to the tune advance.

        For each plane, the contribution is the Taylor series
        sum_{k=1..order} chro_k / k! * delta**k, where chro_k is the k-th
        order entry of self.chro_diff for that plane.

        This method only reads self.chro_diff, so calling it repeatedly
        (for several bunches or several turns) always gives the same
        result for the same input.

        Parameters
        ----------
        bunch : Bunch or Beam object
            Object giving access to the "delta" coordinate.

        Returns
        -------
        tune_advance_x, tune_advance_y : tuple
            Chromatic tune advance in the horizontal and vertical planes,
            with one value per entry of bunch["delta"].
        """
        delta = bunch["delta"]
        chro = self.chro_diff
        order = len(chro) // 2

        tune_advance_x = 0
        tune_advance_y = 0
        k_factorial = 1
        for k in range(1, order + 1):
            k_factorial *= k
            delta_k = delta**k
            # chro stores (x, y) pairs: order k sits at indices 2k-2, 2k-1.
            tune_advance_x = (tune_advance_x +
                              chro[2*k - 2] / k_factorial * delta_k)
            tune_advance_y = (tune_advance_y +
                              chro[2*k - 1] / k_factorial * delta_k)

        return tune_advance_x, tune_advance_y

    def _compute_new_coords(self, bunch, tune_advance, plane):
        """
        Transport the (position, angle) pair of one plane through the sector.

        Parameters
        ----------
        bunch : Bunch or Beam object
            Object giving access to the coordinates of the plane and to
            "delta".
        tune_advance : float or array
            Tune advance over the sector for this plane (in units of 2*pi).
        plane : {'x', 'y'}
            Plane to transport.

        Returns
        -------
        u, up : tuple
            New position and angle in the requested plane.

        Raises
        ------
        ValueError
            If plane is neither 'x' nor 'y'.
        """
        # i indexes the Twiss arrays, j the dispersion arrays
        # (dispersion[j] multiplies the position, dispersion[j + 1] the angle).
        if plane == 'x':
            i, j, coord, mom = 0, 0, 'x', 'xp'
        elif plane == 'y':
            i, j, coord, mom = 1, 2, 'y', 'yp'
        else:
            raise ValueError('plane should be either x or y')

        alpha0, alpha1 = self.alpha0[i], self.alpha1[i]
        beta0, beta1 = self.beta0[i], self.beta1[i]
        disp0, disp0_p = self.dispersion0[j], self.dispersion0[j + 1]
        disp1, disp1_p = self.dispersion1[j], self.dispersion1[j + 1]

        c_u = np.cos(2 * np.pi * tune_advance)
        s_u = np.sin(2 * np.pi * tune_advance)

        # Twiss transfer matrix between the two ends of the sector; the third
        # column is set so that the dispersion is transported consistently.
        M00 = np.sqrt(beta1 / beta0) * (c_u + alpha0 * s_u)
        M01 = np.sqrt(beta0 * beta1) * s_u
        M02 = (disp1 - M00*disp0 - M01*disp0_p)
        M10 = ((alpha0 - alpha1) * c_u -
               (1 + alpha0*alpha1) * s_u) / np.sqrt(beta0 * beta1)
        M11 = np.sqrt(beta0 / beta1) * (c_u - alpha1*s_u)
        M12 = (disp1_p - M10*disp0 - M11*disp0_p)

        u = M00 * bunch[coord] + M01 * bunch[mom] + M02 * bunch["delta"]
        up = M10 * bunch[coord] + M11 * bunch[mom] + M12 * bunch["delta"]
        return u, up

    @Element.parallel
    def track(self, bunch: Bunch | Beam):
        """
        Tracking method for the element.
        No bunch to bunch interaction, so written for Bunch objects and
        @Element.parallel is used to handle Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        tune_advance_x = self.tune_diff[0]
        tune_advance_y = self.tune_diff[1]

        # Tune advance depends on energy via chromaticity ...
        if (np.array(self.chro_diff) != 0).any():
            chro_x, chro_y = self._compute_chromatic_tune_advances(bunch)
            # Out-of-place additions: self.tune_diff is never modified.
            tune_advance_x = tune_advance_x + chro_x
            tune_advance_y = tune_advance_y + chro_y

        # ... and on amplitude via ADTS.
        if self.adts_poly is not None:
            Jx = ((self.gamma0[0] * bunch["x"]**2) +
                  (2 * self.alpha0[0] * bunch["x"] * bunch["xp"]) +
                  (self.beta0[0] * bunch["xp"]**2))
            Jy = ((self.gamma0[1] * bunch["y"]**2) +
                  (2 * self.alpha0[1] * bunch["y"] * bunch["yp"]) +
                  (self.beta0[1] * bunch["yp"]**2))
            tune_advance_x = tune_advance_x + (self.adts_poly[0](Jx) +
                                               self.adts_poly[2](Jy))
            tune_advance_y = tune_advance_y + (self.adts_poly[1](Jx) +
                                               self.adts_poly[3](Jy))

        bunch['x'], bunch['xp'] = self._compute_new_coords(
            bunch, tune_advance_x, 'x')
        bunch['y'], bunch['yp'] = self._compute_new_coords(
            bunch, tune_advance_y, 'y')
class TransverseMap(TransverseMapSector):
    """
    One-turn transverse map of the synchrotron.

    Built as a TransverseMapSector whose initial and final locations both use
    the local optics of the ring, with the full tune, chromaticity and
    amplitude-dependent tune shift of the ring.

    Parameters
    ----------
    ring : Synchrotron object
    """

    def __init__(self, ring: Synchrotron):
        optics = ring.optics
        one_turn_phase = 2 * np.pi * ring.tune
        super().__init__(
            ring,
            # Initial location of the sector.
            optics.local_alpha,
            optics.local_beta,
            optics.local_dispersion,
            # Final location of the sector (same point, one turn later).
            optics.local_alpha,
            optics.local_beta,
            optics.local_dispersion,
            one_turn_phase,
            ring.chro,
            ring.adts,
        )
def transverse_map_sector_generator(ring: Synchrotron, positions: NDArray,
                                    **kwargs) -> list[TransverseMapSector]:
    """
    Convenience function which generates a list of TransverseMapSector
    elements from a ring:
        - if an AT lattice is loaded, the optics functions and chromaticity
        are computed at the given positions.
        - if no AT lattice is loaded, the local optics are used everywhere.

    Tracking through all the sectors is equivalent to a full turn (and thus to
    the TransverseMap object).

    In both cases the chromaticity and the amplitude-dependent tune shift of
    the ring are shared equally between the sectors.

    Parameters
    ----------
    ring : Synchrotron object
        Ring parameters.
    positions : array
        list of longitudinal positions in [m] to use as starting and end
        points of the TransverseMapSector elements.
        The array should contain the initial position (s=0) but not the end
        position (s=ring.L), so like position = np.array([0, pos1, pos2, ...]).
    **kwargs
        Only used when an AT lattice is loaded:
        dp : float, optional
            Passed as ``dpm`` to at.physics.nonlinear.chromaticity.
            The default is 1e-2.
        order : int, optional
            Passed as ``order`` to at.physics.nonlinear.chromaticity.
            The default is 1.

    Returns
    -------
    sectors : list
        list of TransverseMapSector elements.
    """
    N_sec = len(positions)
    optics = ring.optics

    # Kept as a list of per-term arrays (not a 2D array) so that coefficient
    # sets of different lengths are supported; sectors only index it.
    ring_adts = getattr(ring, "adts", None)
    if ring_adts is not None:
        adts = [np.asarray(val) / N_sec for val in ring_adts]
    else:
        adts = None

    if optics.use_local_values:
        return [
            TransverseMapSector(ring,
                                optics.local_alpha,
                                optics.local_beta,
                                optics.local_dispersion,
                                optics.local_alpha,
                                optics.local_beta,
                                optics.local_dispersion,
                                2 * np.pi * ring.tune / N_sec,
                                np.asarray(ring.chro) / N_sec,
                                adts=adts) for _ in range(N_sec)
        ]

    import at

    dp = kwargs.get('dp', 1e-2)
    order = kwargs.get('order', 1)

    def _compute_chro_per_sector(ring, N_sec, dp, order):
        """Return the chromaticity share of one sector, as (x, y) pairs."""
        lat = deepcopy(ring.optics.lattice)
        lat.append(at.Marker("END"))
        fit, _, _ = at.physics.nonlinear.chromaticity(lat,
                                                      method='linopt',
                                                      dpm=dp,
                                                      n_points=100,
                                                      order=order)
        # Interleave x and y, skipping the 0th-order fit coefficient:
        # [x1, y1, x2, y2, ...].
        chro_xy = [
            elem for pair in zip(fit[0, 1:], fit[1, 1:]) for elem in pair
        ]
        len_chro = int(order * 2)
        # N_sec sectors means N_sec equal shares, including the last sector
        # which closes the ring. Spreading the total over N_sec points (i.e.
        # N_sec - 1 intervals) would leave the last sector without any
        # chromaticity, and a single-sector ring with none at all.
        return np.array(chro_xy[:len_chro], dtype=float) / N_sec

    chro_per_sector = _compute_chro_per_sector(ring, N_sec, dp, order)

    sectors = []
    for i in range(N_sec):
        s0 = positions[i]
        alpha0 = optics.alpha(s0)
        beta0 = optics.beta(s0)
        dispersion0 = optics.dispersion(s0)
        mu0 = optics.mu(s0)

        if i != (N_sec - 1):
            s1 = positions[i + 1]
            mu1 = optics.mu(s1)
        else:
            # Last sector closes the ring: optics of the first position, but
            # phase taken at s = ring.L so that the phase advance is positive.
            s1 = positions[0]
            mu1 = optics.mu(ring.L)
        alpha1 = optics.alpha(s1)
        beta1 = optics.beta(s1)
        dispersion1 = optics.dispersion(s1)

        sectors.append(
            TransverseMapSector(ring,
                                alpha0,
                                beta0,
                                dispersion0,
                                alpha1,
                                beta1,
                                dispersion1,
                                mu1 - mu0,
                                chro_per_sector.copy(),
                                adts=adts))
    return sectors