Example Cases

Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as initial templates for building your own case studies.

Axial Stretching

  1""" Axial stretching test-case
  2
  3    Assume we have a rod lying aligned in the x-direction, with high internal
  4    damping.
  5
  6    We fix one end (say, the left end) of the rod to a wall. On the right
  7    end we apply a force directed axially pulling the rods tip. Linear
  8    theory (assuming small displacements) predicts that the net displacement
  9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10    usual meaning (the rod is just a linear spring). We compare our results
 11    with the above result.
 12
 13    We can "improve" the theory by having a better estimate for the rod's
 14    spring constant by assuming that it equilibrates under the new position,
 15    with
 16    Δx = F * (L + Δx)/ (A * E)
 17    which results in Δx = (F*L)/(A*E - F). Our rod reaches equilibrium with
 18    respect to this position.
 19
 20    Note that if the damping is not high, the rod oscillates about the eventual
 21    resting position (and this agrees with the theoretical predictions without
 22    any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24    isort:skip_file
 25"""
 26import numpy as np
 27from matplotlib import pyplot as plt
 28
 29import elastica as ea
 30
 31
 32class StretchingBeamSimulator(
 33    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
 34):
 35    pass
 36
 37
 38stretch_sim = StretchingBeamSimulator()
 39final_time = 200.0
 40
 41# Options
 42PLOT_FIGURE = True
 43SAVE_FIGURE = False
 44SAVE_RESULTS = False
 45
 46# setting up test params
 47n_elem = 19
 48start = np.zeros((3,))
 49direction = np.array([1.0, 0.0, 0.0])
 50normal = np.array([0.0, 1.0, 0.0])
 51base_length = 1.0
 52base_radius = 0.025
 53base_area = np.pi * base_radius ** 2
 54density = 1000
 55youngs_modulus = 1e4
 56# For shear modulus of 1e4, nu is 99!
 57poisson_ratio = 0.5
 58shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 59
 60stretchable_rod = ea.CosseratRod.straight_rod(
 61    n_elem,
 62    start,
 63    direction,
 64    normal,
 65    base_length,
 66    base_radius,
 67    density,
 68    youngs_modulus=youngs_modulus,
 69    shear_modulus=shear_modulus,
 70)
 71
 72stretch_sim.append(stretchable_rod)
 73stretch_sim.constrain(stretchable_rod).using(
 74    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 75)
 76
 77end_force_x = 1.0
 78end_force = np.array([end_force_x, 0.0, 0.0])
 79stretch_sim.add_forcing_to(stretchable_rod).using(
 80    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
 81)
 82
 83# add damping
 84dl = base_length / n_elem
 85dt = 0.1 * dl
 86damping_constant = 0.1
 87stretch_sim.dampen(stretchable_rod).using(
 88    ea.AnalyticalLinearDamper,
 89    damping_constant=damping_constant,
 90    time_step=dt,
 91)
 92
 93# Add call backs
 94class AxialStretchingCallBack(ea.CallBackBaseClass):
 95    """
 96    Tracks the velocity norms of the rod
 97    """
 98
 99    def __init__(self, step_skip: int, callback_params: dict):
100        ea.CallBackBaseClass.__init__(self)
101        self.every = step_skip
102        self.callback_params = callback_params
103
104    def make_callback(self, system, time, current_step: int):
105
106        if current_step % self.every == 0:
107
108            self.callback_params["time"].append(time)
109            # Collect only x
110            self.callback_params["position"].append(
111                system.position_collection[0, -1].copy()
112            )
113            self.callback_params["velocity_norms"].append(
114                np.linalg.norm(system.velocity_collection.copy())
115            )
116            return
117
118
recorded_history = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="simulation (tip x)",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="improved theory",
    )
    ax.set_xlabel("time")
    ax.set_ylabel("tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        # (2, n_samples) -> (n_samples, 2): one row per sample (time, norm)
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
  7
  8
  9class TimoshenkoBeamSimulator(
 10    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
 11):
 12    pass
 13
 14
 15timoshenko_sim = TimoshenkoBeamSimulator()
 16final_time = 5000.0
 17
 18# Options
 19PLOT_FIGURE = True
 20SAVE_FIGURE = True
 21SAVE_RESULTS = False
 22ADD_UNSHEARABLE_ROD = False
 23
 24# setting up test params
 25n_elem = 100
 26start = np.zeros((3,))
 27direction = np.array([0.0, 0.0, 1.0])
 28normal = np.array([0.0, 1.0, 0.0])
 29base_length = 3.0
 30base_radius = 0.25
 31base_area = np.pi * base_radius ** 2
 32density = 5000
 33nu = 0.1 / 7 / density / base_area
 34E = 1e6
 35# For shear modulus of 1e4, nu is 99!
 36poisson_ratio = 99
 37shear_modulus = E / (poisson_ratio + 1.0)
 38
 39shearable_rod = ea.CosseratRod.straight_rod(
 40    n_elem,
 41    start,
 42    direction,
 43    normal,
 44    base_length,
 45    base_radius,
 46    density,
 47    youngs_modulus=E,
 48    shear_modulus=shear_modulus,
 49)
 50
 51timoshenko_sim.append(shearable_rod)
 52# add damping
 53dl = base_length / n_elem
 54dt = 0.07 * dl
 55timoshenko_sim.dampen(shearable_rod).using(
 56    ea.AnalyticalLinearDamper,
 57    damping_constant=nu,
 58    time_step=dt,
 59)
 60
 61timoshenko_sim.constrain(shearable_rod).using(
 62    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 63)
 64
 65end_force = np.array([-15.0, 0.0, 0.0])
 66timoshenko_sim.add_forcing_to(shearable_rod).using(
 67    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 68)
 69
 70
 71if ADD_UNSHEARABLE_ROD:
 72    # Start into the plane
 73    unshearable_start = np.array([0.0, -1.0, 0.0])
 74    shear_modulus = E / (-0.7 + 1.0)
 75    unshearable_rod = ea.CosseratRod.straight_rod(
 76        n_elem,
 77        unshearable_start,
 78        direction,
 79        normal,
 80        base_length,
 81        base_radius,
 82        density,
 83        youngs_modulus=E,
 84        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
 85        shear_modulus=shear_modulus,
 86    )
 87
 88    timoshenko_sim.append(unshearable_rod)
 89
 90    # add damping
 91    timoshenko_sim.dampen(unshearable_rod).using(
 92        ea.AnalyticalLinearDamper,
 93        damping_constant=nu,
 94        time_step=dt,
 95    )
 96    timoshenko_sim.constrain(unshearable_rod).using(
 97        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 98    )
 99    timoshenko_sim.add_forcing_to(unshearable_rod).using(
100        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
101    )
102
103# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Records the simulation time and the norm of the rod's full velocity array
    every ``step_skip`` steps.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return
        self.callback_params["time"].append(time)
        # norm over all node velocities (Frobenius norm of the 2-D array)
        velocity_norm = np.linalg.norm(system.velocity_collection.copy())
        self.callback_params["velocity_norms"].append(velocity_norm)
124
125
recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack, step_skip=500, callback_params=recorded_history
)

timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        # (2, n_samples) -> (n_samples, 2): one row per sample (time, norm)
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Butterfly

  1import numpy as np
  2from matplotlib import pyplot as plt
  3from matplotlib.colors import to_rgb
  4
  5
  6import elastica as ea
  7from elastica.utils import MaxDimension
  8
  9
 10class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
 11    pass
 12
 13
 14butterfly_sim = ButterflySimulator()
 15final_time = 40.0
 16
 17# Options
 18PLOT_FIGURE = True
 19SAVE_FIGURE = True
 20SAVE_RESULTS = True
 21ADD_UNSHEARABLE_ROD = False
 22
 23# setting up test params
 24# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
 25n_elem = 4  # Change based on requirements, but be careful
 26n_elem += n_elem % 2
 27half_n_elem = n_elem // 2
 28
 29origin = np.zeros((3, 1))
 30angle_of_inclination = np.deg2rad(45.0)
 31
 32# in-plane
 33horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
 34vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
 35
 36# out-of-plane
 37normal = np.array([0.0, 1.0, 0.0])
 38
 39total_length = 3.0
 40base_radius = 0.25
 41base_area = np.pi * base_radius ** 2
 42density = 5000
 43youngs_modulus = 1e4
 44poisson_ratio = 0.5
 45shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 46
 47positions = np.empty((MaxDimension.value(), n_elem + 1))
 48dl = total_length / n_elem
 49
 50# First half of positions stem from slope angle_of_inclination
 51first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
 52positions[..., : half_n_elem + 1] = origin + dl * first_half * (
 53    np.cos(angle_of_inclination) * horizontal_direction
 54    + np.sin(angle_of_inclination) * vertical_direction
 55)
 56positions[..., half_n_elem:] = positions[
 57    ..., half_n_elem : half_n_elem + 1
 58] + dl * first_half * (
 59    np.cos(angle_of_inclination) * horizontal_direction
 60    - np.sin(angle_of_inclination) * vertical_direction
 61)
 62
 63butterfly_rod = ea.CosseratRod.straight_rod(
 64    n_elem,
 65    start=origin.reshape(3),
 66    direction=np.array([0.0, 0.0, 1.0]),
 67    normal=normal,
 68    base_length=total_length,
 69    base_radius=base_radius,
 70    density=density,
 71    youngs_modulus=youngs_modulus,
 72    shear_modulus=shear_modulus,
 73    position=positions,
 74)
 75
 76butterfly_sim.append(butterfly_rod)
 77
 78# Add call backs
 79class VelocityCallBack(ea.CallBackBaseClass):
 80    """
 81    Call back function for continuum snake
 82    """
 83
 84    def __init__(self, step_skip: int, callback_params: dict):
 85        ea.CallBackBaseClass.__init__(self)
 86        self.every = step_skip
 87        self.callback_params = callback_params
 88
 89    def make_callback(self, system, time, current_step: int):
 90
 91        if current_step % self.every == 0:
 92
 93            self.callback_params["time"].append(time)
 94            # Collect x
 95            self.callback_params["position"].append(system.position_collection.copy())
 96            # Collect energies as well
 97            self.callback_params["te"].append(system.compute_translational_energy())
 98            self.callback_params["re"].append(system.compute_rotational_energy())
 99            self.callback_params["se"].append(system.compute_shear_energy())
100            self.callback_params["be"].append(system.compute_bending_energy())
101            return
102
103
recorded_history = ea.defaultdict(list)
# Seed the history with the initial (t = 0) state before any time stepping.
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
for energy_key, compute_energy in (
    ("te", butterfly_rod.compute_translational_energy),
    ("re", butterfly_rod.compute_rotational_energy),
    ("se", butterfly_rod.compute_shear_energy),
    ("be", butterfly_rod.compute_bending_energy),
):
    recorded_history[energy_key].append(compute_energy())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)
126
if PLOT_FIGURE:
    # Plot the centreline history in the z (horizontal) - x (vertical) plane.
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Index into the history rather than popping from it: recorded_history is
    # left intact and a single-snapshot history does not raise IndexError.
    position_history = recorded_history["position"]
    # first position, highlighted
    first_position = position_history[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    later_positions = position_history[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        # Later snapshots are more opaque (alpha rises from 1/e towards 1).
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final position is also drawn separately
    last_position = position_history[-1]
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling

  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
  7    plot_helicalbuckling,
  8)
  9
 10
 11class HelicalBucklingSimulator(
 12    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
 13):
 14    pass
 15
 16
 17helicalbuckling_sim = HelicalBucklingSimulator()
 18
 19# Options
 20PLOT_FIGURE = True
 21SAVE_FIGURE = True
 22SAVE_RESULTS = False
 23
 24# setting up test params
 25n_elem = 100
 26start = np.zeros((3,))
 27direction = np.array([0.0, 0.0, 1.0])
 28normal = np.array([0.0, 1.0, 0.0])
 29base_length = 100.0
 30base_radius = 0.35
 31base_area = np.pi * base_radius ** 2
 32density = 1.0 / (base_area)
 33nu = 0.01 / density / base_area
 34E = 1e6
 35slack = 3
 36number_of_rotations = 27
 37# For shear modulus of 1e5, nu is 99!
 38poisson_ratio = 9
 39shear_modulus = E / (poisson_ratio + 1.0)
 40shear_matrix = np.repeat(
 41    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
 42)
 43temp_bend_matrix = np.zeros((3, 3))
 44np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
 45bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)
 46
 47shearable_rod = ea.CosseratRod.straight_rod(
 48    n_elem,
 49    start,
 50    direction,
 51    normal,
 52    base_length,
 53    base_radius,
 54    density,
 55    youngs_modulus=E,
 56    shear_modulus=shear_modulus,
 57)
 58# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below
 59
 60shearable_rod.shear_matrix = shear_matrix
 61shearable_rod.bend_matrix = bend_matrix
 62
 63
 64helicalbuckling_sim.append(shearable_rod)
 65# add damping
 66dl = base_length / n_elem
 67dt = 1e-3 * dl
 68helicalbuckling_sim.dampen(shearable_rod).using(
 69    ea.AnalyticalLinearDamper,
 70    damping_constant=nu,
 71    time_step=dt,
 72)
 73
 74helicalbuckling_sim.constrain(shearable_rod).using(
 75    ea.HelicalBucklingBC,
 76    constrained_position_idx=(0, -1),
 77    constrained_director_idx=(0, -1),
 78    twisting_time=500,
 79    slack=slack,
 80    number_of_rotations=number_of_rotations,
 81)
 82
 83helicalbuckling_sim.finalize()
 84timestepper = ea.PositionVerlet()
 85shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
 86# timestepper = PEFRL()
 87
 88final_time = 10500.0
 89total_steps = int(final_time / dt)
 90print("Total steps", total_steps)
 91ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
 92
 93if PLOT_FIGURE:
 94    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
 95
 96if SAVE_RESULTS:
 97    import pickle
 98
 99    filename = "HelicalBuckling_data.dat"
100    file = open(filename, "wb")
101    pickle.dump(shearable_rod, file)
102    file.close()

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import os
  4import numpy as np
  5import elastica as ea
  6
  7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
  8    plot_snake_velocity,
  9    plot_video,
 10    compute_projected_velocity,
 11    plot_curvature,
 12)
 13
 14
 15class SnakeSimulator(
 16    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
 17):
 18    pass
 19
 20
 21def run_snake(
 22    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 23):
 24    # Initialize the simulation class
 25    snake_sim = SnakeSimulator()
 26
 27    # Simulation parameters
 28    period = 2
 29    final_time = (11.0 + 0.01) * period
 30
 31    # setting up test params
 32    n_elem = 50
 33    start = np.zeros((3,))
 34    direction = np.array([0.0, 0.0, 1.0])
 35    normal = np.array([0.0, 1.0, 0.0])
 36    base_length = 0.35
 37    base_radius = base_length * 0.011
 38    density = 1000
 39    E = 1e6
 40    poisson_ratio = 0.5
 41    shear_modulus = E / (poisson_ratio + 1.0)
 42
 43    shearable_rod = ea.CosseratRod.straight_rod(
 44        n_elem,
 45        start,
 46        direction,
 47        normal,
 48        base_length,
 49        base_radius,
 50        density,
 51        youngs_modulus=E,
 52        shear_modulus=shear_modulus,
 53    )
 54
 55    snake_sim.append(shearable_rod)
 56
 57    # Add gravitational forces
 58    gravitational_acc = -9.80665
 59    snake_sim.add_forcing_to(shearable_rod).using(
 60        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 61    )
 62
 63    # Add muscle torques
 64    wave_length = b_coeff[-1]
 65    snake_sim.add_forcing_to(shearable_rod).using(
 66        ea.MuscleTorques,
 67        base_length=base_length,
 68        b_coeff=b_coeff[:-1],
 69        period=period,
 70        wave_number=2.0 * np.pi / (wave_length),
 71        phase_shift=0.0,
 72        rest_lengths=shearable_rod.rest_lengths,
 73        ramp_up_time=period,
 74        direction=normal,
 75        with_spline=True,
 76    )
 77
 78    # Add friction forces
 79    origin_plane = np.array([0.0, -base_radius, 0.0])
 80    normal_plane = normal
 81    slip_velocity_tol = 1e-8
 82    froude = 0.1
 83    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 84    kinetic_mu_array = np.array(
 85        [mu, 1.5 * mu, 2.0 * mu]
 86    )  # [forward, backward, sideways]
 87    static_mu_array = np.zeros(kinetic_mu_array.shape)
 88    snake_sim.add_forcing_to(shearable_rod).using(
 89        ea.AnisotropicFrictionalPlane,
 90        k=1.0,
 91        nu=1e-6,
 92        plane_origin=origin_plane,
 93        plane_normal=normal_plane,
 94        slip_velocity_tol=slip_velocity_tol,
 95        static_mu_array=static_mu_array,
 96        kinetic_mu_array=kinetic_mu_array,
 97    )
 98
 99    # add damping
100    damping_constant = 2e-3
101    time_step = 1e-4
102    snake_sim.dampen(shearable_rod).using(
103        ea.AnalyticalLinearDamper,
104        damping_constant=damping_constant,
105        time_step=time_step,
106    )
107
108    total_steps = int(final_time / time_step)
109    rendering_fps = 60
110    step_skip = int(1.0 / (rendering_fps * time_step))
111
112    # Add call backs
113    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
114        """
115        Call back function for continuum snake
116        """
117
118        def __init__(self, step_skip: int, callback_params: dict):
119            ea.CallBackBaseClass.__init__(self)
120            self.every = step_skip
121            self.callback_params = callback_params
122
123        def make_callback(self, system, time, current_step: int):
124
125            if current_step % self.every == 0:
126
127                self.callback_params["time"].append(time)
128                self.callback_params["step"].append(current_step)
129                self.callback_params["position"].append(
130                    system.position_collection.copy()
131                )
132                self.callback_params["velocity"].append(
133                    system.velocity_collection.copy()
134                )
135                self.callback_params["avg_velocity"].append(
136                    system.compute_velocity_center_of_mass()
137                )
138
139                self.callback_params["center_of_mass"].append(
140                    system.compute_position_center_of_mass()
141                )
142                self.callback_params["curvature"].append(system.kappa.copy())
143
144                return
145
146    pp_list = ea.defaultdict(list)
147    snake_sim.collect_diagnostics(shearable_rod).using(
148        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
149    )
150
151    snake_sim.finalize()
152
153    timestepper = ea.PositionVerlet()
154    ea.integrate(timestepper, snake_sim, final_time, total_steps)
155
156    if PLOT_FIGURE:
157        filename_plot = "continuum_snake_velocity.png"
158        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
159        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
160
161        if SAVE_VIDEO:
162            filename_video = "continuum_snake.mp4"
163            plot_video(
164                pp_list,
165                video_name=filename_video,
166                fps=rendering_fps,
167                xlim=(0, 4),
168                ylim=(-1, 1),
169            )
170
171    if SAVE_RESULTS:
172        import pickle
173
174        filename = "continuum_snake.dat"
175        file = open(filename, "wb")
176        pickle.dump(pp_list, file)
177        file.close()
178
179    # Compute the average forward velocity. These will be used for optimization.
180    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
181
182    return avg_forward, avg_lateral, pp_list
183
184
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: negative average forward velocity (minimised)."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. cma.fmin takes the objective, the
        # initial guess (7 values: 6 spline coefficients + the wave length) and
        # the initial standard deviation. It returns a result list whose first
        # entry is the best solution found, so keep only that array.
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            # Append the wave length as the last coefficient, as run_snake expects.
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)