Example Cases#

Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow, shown here. Feel free to use them as initial templates for building your own case study.

Axial Stretching#

  1"""Axial stretching test-case
  2
  3Assume we have a rod lying aligned in the x-direction, with high internal
  4damping.
  5
  6We fix one end (say, the left end) of the rod to a wall. On the right
  7end we apply a force directed axially pulling the rod's tip. Linear
  8theory (assuming small displacements) predicts that the net displacement
  9experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10usual meaning (the rod is just a linear spring). We compare our results
 11with the above result.
 12
 13We can "improve" the theory by having a better estimate for the rod's
 14spring constant by assuming that it equilibrates under the new position,
 15with
 16Δx = F * (L + Δx)/ (A * E)
 17which results in Δx = (F*L)/(A*E - F). Our rod reaches equilibrium with
 18respect to this position.
 19
 20Note that if the damping is not high, the rod oscillates about the eventual
 21resting position (and this agrees with the theoretical predictions without
 22any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24isort:skip_file
 25"""
 26
 27import numpy as np
 28from matplotlib import pyplot as plt
 29
 30import elastica as ea
 31
 32
 33class StretchingBeamSimulator(
 34    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
 35):
 36    pass
 37
 38
# Simulator that collects the rod and every constraint / force / damper
# registered on it below.
stretch_sim = StretchingBeamSimulator()
final_time = 200.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False

# setting up test params
n_elem = 19
start = np.zeros((3,))
# The rod axis lies along x, matching the module docstring.
direction = np.array([1.0, 0.0, 0.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 1.0
base_radius = 0.025
# Cross-sectional area of a circle with radius base_radius; used for the
# analytical tip displacement further down.
base_area = np.pi * base_radius**2
density = 1000
youngs_modulus = 1e4
# With a Poisson ratio of 0.5 the shear modulus below evaluates to
# youngs_modulus / 1.5.
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

stretchable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
)

stretch_sim.append(stretchable_rod)
# Fix position index 0 and director index 0: this is the end "fixed to a
# wall" described in the module docstring.
stretch_sim.constrain(stretchable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

end_force_x = 1.0
end_force = np.array([end_force_x, 0.0, 0.0])
# The first force argument is identically zero (0.0 * end_force); the second
# is the axial pull. Presumably these are the start-end and tip-end forces
# respectively -- confirm against the EndpointForces signature.
stretch_sim.add_forcing_to(stretchable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
)

# add damping
# dl is the length of one element; the time step is chosen as a fixed
# fraction (0.1) of it.
dl = base_length / n_elem
dt = 0.1 * dl
damping_constant = 0.1
stretch_sim.dampen(stretchable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=damping_constant,
    time_step=dt,
)
 93
 94
 95# Add call backs
 96class AxialStretchingCallBack(ea.CallBackBaseClass):
 97    """
 98    Tracks the velocity norms of the rod
 99    """
100
101    def __init__(self, step_skip: int, callback_params: dict) -> None:
102        super().__init__()
103        self.every = step_skip
104        self.callback_params = callback_params
105
106    def make_callback(
107        self, system: ea.typing.RodType, time: np.float64, current_step: int
108    ) -> None:
109
110        if current_step % self.every == 0:
111
112            self.callback_params["time"].append(time)
113            # Collect only x
114            self.callback_params["position"].append(
115                system.position_collection[0, -1].copy()
116            )
117            self.callback_params["velocity_norms"].append(
118                np.linalg.norm(system.velocity_collection.copy())
119            )
120            return
121
122
# Shared store the callback appends to; missing keys start as empty lists.
recorded_history: dict[str, list] = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

# All registrations must be made before this call; integration follows.
stretch_sim.finalize()
timestepper: ea.typing.StepperProtocol = ea.PositionVerlet()
# timestepper = PEFRL()

# Number of steps of size dt needed to cover final_time (truncated).
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)
135
if PLOT_FIGURE:
    # Linear-spring prediction evaluated with the undeformed length: F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # Prediction using the stretched length, F L / (A E - F); a better estimate
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(recorded_history["time"], recorded_history["position"], lw=2.0)

    # Horizontal reference lines at the two predicted tip positions:
    # thin dash-dot for the first-order value, thick dashed for the improved one.
    reference_lines = (
        (expected_tip_disp, "dashdot", 1.0),
        (expected_tip_disp_improved, "dashed", 2.0),
    )
    for tip_disp, line_style, line_width in reference_lines:
        ax.hlines(
            base_length + tip_disp, 0.0, final_time, "k", line_style, lw=line_width
        )

    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()
154
if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    # Row 0: sample times, row 1: velocity norms recorded by the callback.
    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v: np.ndarray) -> np.ndarray:
        """Transpose so that each row of the output is one time sample."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Timoshenko#

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
  7
  8
  9class TimoshenkoBeamSimulator(
 10    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
 11):
 12    pass
 13
 14
# Simulator that collects the rod(s) and everything registered on them below.
timoshenko_sim = TimoshenkoBeamSimulator()
final_time = 5000.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False
# When True a second rod with a much larger shear modulus is simulated as well.
ADD_UNSHEARABLE_ROD = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2
density = 5000
# Damping constant passed to AnalyticalLinearDamper below. Despite the
# name this is NOT the Poisson ratio (that is `poisson_ratio`).
nu = 0.1 / 7 / density / base_area
E = 1e6
# A Poisson ratio of 99 makes the shear modulus E / 100 = 1e4.
poisson_ratio = 99
shear_modulus = E / (poisson_ratio + 1.0)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)

timoshenko_sim.append(shearable_rod)
# add damping
# dl is the length of one element; the time step is a fixed fraction of it.
dl = base_length / n_elem
dt = 0.07 * dl
timoshenko_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Fix position index 0 and director index 0 of the rod.
timoshenko_sim.constrain(shearable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# The load acts along -x while the rod axis is along z, i.e. transverse
# to the rod.
end_force = np.array([-15.0, 0.0, 0.0])
# The first force argument is identically zero; the load is ramped up over
# the first half of the simulated time.
timoshenko_sim.add_forcing_to(shearable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
)
 69
 70
# Optional comparison rod: same geometry, damping, boundary condition and
# load as the shearable rod, but a much larger shear modulus.
if ADD_UNSHEARABLE_ROD:
    # Start into the plane
    # (offset by -1 along y relative to the shearable rod's start)
    unshearable_start = np.array([0.0, -1.0, 0.0])
    # NOTE: this rebinds the module-level `shear_modulus`; with a Poisson
    # ratio of -0.7 it becomes E / 0.3.
    shear_modulus = E / (-0.7 + 1.0)
    unshearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        unshearable_start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        # Unshearable rod needs G -> inf; approximated here by using a
        # negative Poisson ratio to raise G.
        shear_modulus=shear_modulus,
    )

    timoshenko_sim.append(unshearable_rod)

    # add damping
    timoshenko_sim.dampen(unshearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=nu,
        time_step=dt,
    )
    timoshenko_sim.constrain(unshearable_rod).using(
        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
    )
    timoshenko_sim.add_forcing_to(unshearable_rod).using(
        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
    )
102
103
104# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Records the time and the velocity norm of the rod every `step_skip` steps
    """

    def __init__(self, step_skip: int, callback_params: dict):
        super().__init__()
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        # Only sample on multiples of the configured stride.
        if current_step % self.every != 0:
            return

        history = self.callback_params
        history["time"].append(time)
        # Norm taken over the whole velocity array (all components, all nodes).
        speed = np.linalg.norm(system.velocity_collection.copy())
        history["velocity_norms"].append(speed)
125
126
# Shared store the callback appends to; missing keys start as empty lists.
# Diagnostics are collected for the shearable rod only.
recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack, step_skip=500, callback_params=recorded_history
)

# All registrations must be made before this call; integration follows.
timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

# Number of steps of size dt needed to cover final_time (truncated).
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    # NOTE(review): only the flag ADD_UNSHEARABLE_ROD is passed here, not the
    # unshearable rod object itself -- confirm how plot_timoshenko obtains it.
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)
142
if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    # Row 0: sample times, row 1: velocity norms recorded by the callback.
    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        """Transpose so that each row of the output is one time sample."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Butterfly#

  1import numpy as np
  2from matplotlib import pyplot as plt
  3from matplotlib.colors import to_rgb
  4
  5
  6import elastica as ea
  7from elastica.utils import MaxDimension
  8
  9
 10class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
 11    pass
 12
 13
# Simulator that collects the rod and its diagnostics callback.
butterfly_sim = ButterflySimulator()
final_time = 40.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = True
ADD_UNSHEARABLE_ROD = False

# setting up test params
# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
n_elem = 4  # Change based on requirements, but be careful
# Round n_elem up to the next even number so it splits into two equal halves.
n_elem += n_elem % 2
half_n_elem = n_elem // 2

# Column vector (3, 1) so it broadcasts against the (3, k) position slices.
origin = np.zeros((3, 1))
angle_of_inclination = np.deg2rad(45.0)

# in-plane
horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)

# out-of-plane
normal = np.array([0.0, 1.0, 0.0])

total_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2
density = 5000
youngs_modulus = 1e4
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

# One column per node (n_elem + 1 nodes); every entry is overwritten below.
positions = np.empty((MaxDimension.value(), n_elem + 1))
dl = total_length / n_elem

# First half of positions stem from slope angle_of_inclination
# first_half = [0, 1, ..., half_n_elem] as a (1, half_n_elem + 1) row.
first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
positions[..., : half_n_elem + 1] = origin + dl * first_half * (
    np.cos(angle_of_inclination) * horizontal_direction
    + np.sin(angle_of_inclination) * vertical_direction
)
# Second half continues from the middle node with the vertical component
# negated. Because first_half starts at 0, the middle node itself is
# rewritten with its own value.
positions[..., half_n_elem:] = positions[
    ..., half_n_elem : half_n_elem + 1
] + dl * first_half * (
    np.cos(angle_of_inclination) * horizontal_direction
    - np.sin(angle_of_inclination) * vertical_direction
)

# The explicit `position` array carries the bent initial shape built above.
butterfly_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start=origin.reshape(3),
    direction=np.array([0.0, 0.0, 1.0]),
    normal=normal,
    base_length=total_length,
    base_radius=base_radius,
    density=density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
    position=positions,
)

butterfly_sim.append(butterfly_rod)
 77
 78
 79# Add call backs
 80class VelocityCallBack(ea.CallBackBaseClass):
 81    """
 82    Call back function for continuum snake
 83    """
 84
 85    def __init__(self, step_skip: int, callback_params: dict) -> None:
 86        ea.CallBackBaseClass.__init__(self)
 87        self.every = step_skip
 88        self.callback_params = callback_params
 89
 90    def make_callback(
 91        self, system: ea.typing.RodType, time: np.float64, current_step: int
 92    ) -> None:
 93
 94        if current_step % self.every == 0:
 95
 96            self.callback_params["time"].append(time)
 97            # Collect x
 98            self.callback_params["position"].append(system.position_collection.copy())
 99            # Collect energies as well
100            self.callback_params["te"].append(system.compute_translational_energy())
101            self.callback_params["re"].append(system.compute_rotational_energy())
102            self.callback_params["se"].append(system.compute_shear_energy())
103            self.callback_params["be"].append(system.compute_bending_energy())
104            return
105
106
# Shared store the callback appends to; missing keys start as empty lists.
recorded_history: dict[str, list] = ea.defaultdict(list)
# initially record history
# (seed every series with the state at time 0.0, using the same keys the
# callback writes to)
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
recorded_history["te"].append(butterfly_rod.compute_translational_energy())
recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
recorded_history["se"].append(butterfly_rod.compute_shear_energy())
recorded_history["be"].append(butterfly_rod.compute_bending_energy())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


# All registrations must be made before this call; integration follows.
butterfly_sim.finalize()
timestepper: ea.typing.StepperProtocol
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

# Time step is a fixed fraction of the element length dl.
dt = 0.01 * dl
# Number of steps of size dt needed to cover final_time (truncated).
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)
130
if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    positions_history = recorded_history["position"]
    # Index and slice instead of popping, so that plotting does not remove
    # entries from the recorded history (which would leave "position" out
    # of step with "time" and the energy series).
    first_position = positions_history[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    # Every snapshot after the first, drawn with increasing opacity
    later_positions = positions_history[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final position is also drawn separately, on top of the faded curves
    last_position = positions_history[-1]
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    # Sum of the four contributions
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()
172
if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling#

  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
  7    plot_helicalbuckling,
  8)
  9
 10
 11class HelicalBucklingSimulator(
 12    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
 13):
 14    pass
 15
 16
# Simulator that collects the rod and everything registered on it below.
helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius**2
# Chosen so that density * base_area == 1.
density = 1.0 / (base_area)
# Damping constant passed to AnalyticalLinearDamper below; since
# density * base_area == 1 this evaluates to 0.01. Despite the name it is
# NOT the Poisson ratio (that is `poisson_ratio`).
nu = 0.01 / density / base_area
E = 1e6
slack = 3
number_of_rotations = 27
# A Poisson ratio of 9 makes the shear modulus E / 10 = 1e5.
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# shear_modulus * I(3x3), repeated once per element -> shape (3, 3, n_elem)
shear_matrix = np.repeat(
    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
)
# diag(1.345, 1.345, 0.789), repeated n_elem - 1 times -> shape (3, 3, n_elem - 1)
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

# Overwrite the stiffness matrices the constructor computed with the
# hand-built ones from above.
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
# add damping
# dl is the length of one element; the time step is a fixed fraction of it.
dl = base_length / n_elem
dt = 1e-3 * dl
helicalbuckling_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Both ends (indices 0 and -1) are handed to the boundary condition, together
# with the slack, the number of rotations and the twisting time.
helicalbuckling_sim.constrain(shearable_rod).using(
    ea.HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)
 82
 83helicalbuckling_sim.finalize()
 84timestepper = ea.PositionVerlet()
 85shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
 86# timestepper = PEFRL()
 87
 88final_time = 10500.0
 89total_steps = int(final_time / dt)
 90print("Total steps", total_steps)
 91ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
 92
 93if PLOT_FIGURE:
 94    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
 95
 96if SAVE_RESULTS:
 97    import pickle
 98
 99    filename = "HelicalBuckling_data.dat"
100    file = open(filename, "wb")
101    pickle.dump(shearable_rod, file)
102    file.close()

Continuum Snake#

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import os
  4import numpy as np
  5import elastica as ea
  6from numpy.typing import NDArray
  7from elastica.typing import RodType
  8
  9from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
 10    plot_snake_velocity,
 11    plot_video,
 12    compute_projected_velocity,
 13    plot_curvature,
 14)
 15
 16
 17class SnakeSimulator(
 18    ea.BaseSystemCollection,
 19    ea.Constraints,
 20    ea.Forcing,
 21    ea.Damping,
 22    ea.CallBacks,
 23    ea.Contact,
 24):
 25    pass
 26
 27
 28def run_snake(
 29    b_coeff: NDArray[np.float64],
 30    PLOT_FIGURE: bool = False,
 31    SAVE_FIGURE: bool = False,
 32    SAVE_VIDEO: bool = False,
 33    SAVE_RESULTS: bool = False,
 34) -> tuple[float, float, dict]:
 35    # Initialize the simulation class
 36    snake_sim = SnakeSimulator()
 37
 38    # Simulation parameters
 39    period = 2
 40    final_time = (11.0 + 0.01) * period
 41
 42    # setting up test params
 43    n_elem = 50
 44    start = np.zeros((3,))
 45    direction = np.array([0.0, 0.0, 1.0])
 46    normal = np.array([0.0, 1.0, 0.0])
 47    base_length = 0.35
 48    base_radius = base_length * 0.011
 49    density = 1000
 50    E = 1e6
 51    poisson_ratio = 0.5
 52    shear_modulus = E / (poisson_ratio + 1.0)
 53
 54    shearable_rod = ea.CosseratRod.straight_rod(
 55        n_elem,
 56        start,
 57        direction,
 58        normal,
 59        base_length,
 60        base_radius,
 61        density,
 62        youngs_modulus=E,
 63        shear_modulus=shear_modulus,
 64    )
 65
 66    snake_sim.append(shearable_rod)
 67
 68    # Add gravitational forces
 69    gravitational_acc = -9.80665
 70    snake_sim.add_forcing_to(shearable_rod).using(
 71        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 72    )
 73
 74    # Add muscle torques
 75    wave_length = b_coeff[-1]
 76    snake_sim.add_forcing_to(shearable_rod).using(
 77        ea.MuscleTorques,
 78        base_length=base_length,
 79        b_coeff=b_coeff[:-1],
 80        period=period,
 81        wave_number=2.0 * np.pi / (wave_length),
 82        phase_shift=0.0,
 83        rest_lengths=shearable_rod.rest_lengths,
 84        ramp_up_time=period,
 85        direction=normal,
 86        with_spline=True,
 87    )
 88
 89    # Add friction forces
 90    ground_plane = ea.Plane(
 91        plane_origin=np.array([0.0, -base_radius, 0.0]), plane_normal=normal
 92    )
 93    snake_sim.append(ground_plane)
 94    slip_velocity_tol = 1e-8
 95    froude = 0.1
 96    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 97    kinetic_mu_array = np.array(
 98        [mu, 1.5 * mu, 2.0 * mu]
 99    )  # [forward, backward, sideways]
100    static_mu_array = np.zeros(kinetic_mu_array.shape)
101    snake_sim.detect_contact_between(shearable_rod, ground_plane).using(
102        ea.RodPlaneContactWithAnisotropicFriction,
103        k=1.0,
104        nu=1e-6,
105        slip_velocity_tol=slip_velocity_tol,
106        static_mu_array=static_mu_array,
107        kinetic_mu_array=kinetic_mu_array,
108    )
109
110    # add damping
111    damping_constant = 2e-3
112    time_step = 1e-4
113    snake_sim.dampen(shearable_rod).using(
114        ea.AnalyticalLinearDamper,
115        damping_constant=damping_constant,
116        time_step=time_step,
117    )
118
119    total_steps = int(final_time / time_step)
120    rendering_fps = 60
121    step_skip = int(1.0 / (rendering_fps * time_step))
122
123    # Add call backs
124    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
125        """
126        Call back function for continuum snake
127        """
128
129        def __init__(self, step_skip: int, callback_params: dict) -> None:
130            ea.CallBackBaseClass.__init__(self)
131            self.every = step_skip
132            self.callback_params = callback_params
133
134        def make_callback(
135            self, system: ea.CosseratRod, time: float, current_step: int
136        ) -> None:
137
138            if current_step % self.every == 0:
139
140                self.callback_params["time"].append(time)
141                self.callback_params["step"].append(current_step)
142                self.callback_params["position"].append(
143                    system.position_collection.copy()
144                )
145                self.callback_params["velocity"].append(
146                    system.velocity_collection.copy()
147                )
148                self.callback_params["avg_velocity"].append(
149                    system.compute_velocity_center_of_mass()
150                )
151
152                self.callback_params["center_of_mass"].append(
153                    system.compute_position_center_of_mass()
154                )
155                self.callback_params["curvature"].append(system.kappa.copy())
156
157                self.callback_params["tangents"].append(system.tangents.copy())
158
159                self.callback_params["friction"].append(
160                    self.get_slip_velocity(system).copy()
161                )
162
163                return
164
165        def get_slip_velocity(self, system: RodType) -> NDArray[np.float64]:
166            from elastica.contact_utils import (
167                _find_slipping_elements,
168                _node_to_element_velocity,
169            )
170            from elastica._linalg import _batch_product_k_ik_to_ik, _batch_dot
171
172            axial_direction = system.tangents
173            element_velocity = _node_to_element_velocity(
174                mass=system.mass, node_velocity_collection=system.velocity_collection
175            )
176            velocity_mag_along_axial_direction = _batch_dot(
177                element_velocity, axial_direction
178            )
179            velocity_along_axial_direction = _batch_product_k_ik_to_ik(
180                velocity_mag_along_axial_direction, axial_direction
181            )
182            slip_function_along_axial_direction = _find_slipping_elements(
183                velocity_along_axial_direction, slip_velocity_tol
184            )
185            return slip_function_along_axial_direction
186
187    pp_list: dict[str, list] = ea.defaultdict(list)
188    snake_sim.collect_diagnostics(shearable_rod).using(
189        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
190    )
191
192    snake_sim.finalize()
193
194    timestepper = ea.PositionVerlet()
195    ea.integrate(timestepper, snake_sim, final_time, total_steps)
196
197    if PLOT_FIGURE:
198        filename_plot = "continuum_snake_velocity.png"
199        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
200        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
201
202        if SAVE_VIDEO:
203            filename_video = "continuum_snake.mp4"
204            plot_video(
205                pp_list,
206                video_name=filename_video,
207                fps=rendering_fps,
208                xlim=(0, 4),
209                ylim=(-1, 1),
210            )
211
212    if SAVE_RESULTS:
213        import pickle
214
215        filename = "continuum_snake.dat"
216        file = open(filename, "wb")
217        pickle.dump(pp_list, file)
218        file.close()
219
220    # Compute the average forward velocity. These will be used for optimization.
221    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
222
223    return avg_forward, avg_lateral, pp_list
224
225
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    # When True, optimize the actuation coefficients with CMA-ES instead of
    # running a single simulation.
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient: NDArray[np.float64]) -> float:
            """Objective for CMA: negative average forward velocity."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            # cma minimizes, so negate to maximize the forward velocity.
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are optimizing
        # for and third input is standard deviation you initially set.
        cma_result = cma.fmin(optimize_snake, 7 * [0], 0.5)
        # cma.fmin returns a result tuple whose first entry is the best
        # solution found; the later entries (stop conditions, strategy
        # object, logger) cannot be written with np.savetxt.
        optimized_spline_coefficients = cma_result[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            assert filename_data != "", "provide a file name for coefficients"
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            # Reuse coefficients saved by a previous optimization run.
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            # Built-in defaults: six torque coefficients followed by the
            # wave length, which run_snake reads from the last entry.
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)
280        print("average forward lateral:", avg_lateral)