Example Cases

Example cases demonstrate physical examples that have a known analytical solution or illustrate a well-studied phenomenon. Each case follows the recommended workflow shown here. Feel free to use them as a starting template for your own case study.

Axial Stretching

""" Axial stretching test-case

    Assume we have a rod lying aligned in the x-direction, with high internal
    damping.

    We fix one end (say, the left end) of the rod to a wall. On the right
    end we apply a force directed axially, pulling the rod's tip. Linear
    theory (assuming small displacements) predicts that the net displacement
    experienced by the rod tip is Δx = FL/AE, where the symbols carry their
    usual meaning (the rod is just a linear spring). We compare our results
    with the above result.

    We can "improve" the theory by having a better estimate for the rod's
    spring constant, assuming that it equilibrates under the new position,
    with
    Δx = F * (L + Δx) / (A * E)
    which results in Δx = (F*L)/(A*E - F). Our rod reaches equilibrium with
    respect to this position.

    Note that if the damping is not high, the rod oscillates about the eventual
    resting position (and this agrees with the theoretical predictions without
    any damping: we should see the rod oscillating simple-harmonically in time).

    isort:skip_file
"""
 26import numpy as np
 27from matplotlib import pyplot as plt
 28
 29import elastica as ea
 30
 31
class StretchingBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
):
    """Simulator for the axial-stretching case.

    Combines the system collection with constraint, forcing, damping and
    callback support through its base classes; it defines no behavior of its
    own.
    """

    pass
 36
 37
# Simulator instance and total simulated time.
stretch_sim = StretchingBeamSimulator()
final_time = 200.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False

# setting up test params
# Rod of n_elem elements laid along +x, starting at the origin.
n_elem = 19
start = np.zeros((3,))
direction = np.array([1.0, 0.0, 0.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 1.0
base_radius = 0.025
# Cross-sectional area A used in the analytical estimate Δx = FL/AE.
base_area = np.pi * base_radius ** 2
density = 1000
youngs_modulus = 1e4
# These examples use the convention G = E / (nu + 1); with nu = 0.5 this
# gives a shear modulus of 1e4 / 1.5 (about 6.67e3).
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

stretchable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
)

stretch_sim.append(stretchable_rod)
# Clamp the left end: constrain the position of node 0 and the director of
# element 0.
stretch_sim.constrain(stretchable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Axial pull of magnitude end_force_x along +x. The first force argument is
# zero (presumably the force on the start node) and the second is the tip
# force; the load is ramped up over ramp_up_time.
end_force_x = 1.0
end_force = np.array([end_force_x, 0.0, 0.0])
stretch_sim.add_forcing_to(stretchable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
)

# add damping
# The time step scales with the element length. The damper is given the same
# dt that is later used to compute the number of integration steps.
dl = base_length / n_elem
dt = 0.1 * dl
damping_constant = 0.1
stretch_sim.dampen(stretchable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=damping_constant,
    time_step=dt,
)
 92
 93# Add call backs
 94class AxialStretchingCallBack(ea.CallBackBaseClass):
 95    """
 96    Tracks the velocity norms of the rod
 97    """
 98
 99    def __init__(self, step_skip: int, callback_params: dict):
100        ea.CallBackBaseClass.__init__(self)
101        self.every = step_skip
102        self.callback_params = callback_params
103
104    def make_callback(self, system, time, current_step: int):
105
106        if current_step % self.every == 0:
107
108            self.callback_params["time"].append(time)
109            # Collect only x
110            self.callback_params["position"].append(
111                system.position_collection[0, -1].copy()
112            )
113            self.callback_params["velocity_norms"].append(
114                np.linalg.norm(system.velocity_collection.copy())
115            )
116            return
117
118
# History filled by AxialStretchingCallBack, sampled every 200 steps.
recorded_history = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="simulated tip position",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="improved theory",
    )
    ax.set_xlabel("time")
    ax.set_ylabel("tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    # The context manager closes the file even if pickling raises.
    filename = "axial_stretching_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    # Two-column CSV, one row per sample: time, velocity norm.
    time_series = np.column_stack(
        (
            np.asarray(recorded_history["time"]),
            np.asarray(recorded_history["velocity_norms"]),
        )
    )
    np.savetxt("velocity_norms.csv", time_series, delimiter=",")

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
  7
  8
class TimoshenkoBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
):
    """Simulator for the Timoshenko beam case.

    Combines the system collection with constraint, forcing, callback and
    damping support through its base classes; it defines no behavior of its
    own.
    """

    pass
 13
 14
# Simulator instance and total simulated time.
timoshenko_sim = TimoshenkoBeamSimulator()
final_time = 5000.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False
# When True, a second rod with a much larger shear modulus is simulated next
# to the shearable one, with identical loading, for comparison.
ADD_UNSHEARABLE_ROD = False

# setting up test params
# Rod laid along +z, starting at the origin.
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius ** 2
density = 5000
# NOTE: `nu` is the damping constant handed to the damper below, not a
# Poisson ratio.
nu = 0.1 / 7 / density / base_area
E = 1e6
# For shear modulus of 1e4, nu is 99!
# (convention used here: G = E / (nu + 1) = 1e6 / 100 = 1e4)
poisson_ratio = 99
shear_modulus = E / (poisson_ratio + 1.0)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)

timoshenko_sim.append(shearable_rod)
# add damping
# Time step proportional to the element length; the same dt is used to compute
# the number of integration steps later on.
dl = base_length / n_elem
dt = 0.07 * dl
timoshenko_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Cantilever: constrain the position of node 0 and the director of element 0.
timoshenko_sim.constrain(shearable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Transverse tip load (along -x, perpendicular to the rod axis +z), ramped up
# over the first half of the simulation.
end_force = np.array([-15.0, 0.0, 0.0])
timoshenko_sim.add_forcing_to(shearable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
)


if ADD_UNSHEARABLE_ROD:
    # Start into the plane
    unshearable_start = np.array([0.0, -1.0, 0.0])
    # Reuses the name shear_modulus; the shearable rod above has already been
    # built, so this only affects the second rod.
    shear_modulus = E / (-0.7 + 1.0)
    unshearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        unshearable_start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
        shear_modulus=shear_modulus,
    )

    timoshenko_sim.append(unshearable_rod)

    # add damping
    timoshenko_sim.dampen(unshearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=nu,
        time_step=dt,
    )
    # Same clamp and tip load as the shearable rod.
    timoshenko_sim.constrain(unshearable_rod).using(
        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
    )
    timoshenko_sim.add_forcing_to(unshearable_rod).using(
        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
    )
102
# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Every `step_skip` steps, appends the simulation time and the norm of the
    rod's full velocity array to `callback_params`.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        # Only sample on multiples of the configured interval.
        if current_step % self.every != 0:
            return
        velocity_norm = np.linalg.norm(system.velocity_collection.copy())
        self.callback_params["time"].append(time)
        self.callback_params["velocity_norms"].append(velocity_norm)
124
125
# History filled by VelocityCallBack, sampled every 500 steps.
recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack, step_skip=500, callback_params=recorded_history
)

timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    # The context manager closes the file even if pickling raises.
    filename = "Timoshenko_beam_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    # Two-column CSV, one row per sample: time, velocity norm.
    time_series = np.column_stack(
        (
            np.asarray(recorded_history["time"]),
            np.asarray(recorded_history["velocity_norms"]),
        )
    )
    np.savetxt("velocity_norms.csv", time_series, delimiter=",")

Butterfly

  1import numpy as np
  2from matplotlib import pyplot as plt
  3from matplotlib.colors import to_rgb
  4
  5
  6import elastica as ea
  7from elastica.utils import MaxDimension
  8
  9
class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
    """Simulator for the butterfly case.

    Only callback support is added to the system collection: this simulation
    registers no constraints, forcing or damping.
    """

    pass
 12
 13
# Simulator instance and total simulated time.
butterfly_sim = ButterflySimulator()
final_time = 40.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = True
# NOTE(review): ADD_UNSHEARABLE_ROD is not referenced anywhere else in this
# script.
ADD_UNSHEARABLE_ROD = False

# setting up test params
# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
n_elem = 4  # Change based on requirements, but be careful
# Round an odd element count up to the next even number so the rod splits
# into two equal halves meeting at the middle node.
n_elem += n_elem % 2
half_n_elem = n_elem // 2

origin = np.zeros((3, 1))
angle_of_inclination = np.deg2rad(45.0)

# in-plane
horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)

# out-of-plane
normal = np.array([0.0, 1.0, 0.0])

total_length = 3.0
base_radius = 0.25
# NOTE(review): base_area is not used later in this script.
base_area = np.pi * base_radius ** 2
density = 5000
youngs_modulus = 1e4
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

# Node positions: one column per node (n_elem + 1 columns).
positions = np.empty((MaxDimension.value(), n_elem + 1))
dl = total_length / n_elem

# First half of positions stem from slope angle_of_inclination
first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
positions[..., : half_n_elem + 1] = origin + dl * first_half * (
    np.cos(angle_of_inclination) * horizontal_direction
    + np.sin(angle_of_inclination) * vertical_direction
)
# Second half starts at the middle node and descends with the opposite
# vertical slope, so the rod forms an inverted "V" in the in-plane directions.
positions[..., half_n_elem:] = positions[
    ..., half_n_elem : half_n_elem + 1
] + dl * first_half * (
    np.cos(angle_of_inclination) * horizontal_direction
    - np.sin(angle_of_inclination) * vertical_direction
)

# The `position` keyword supplies the kinked node positions computed above
# (presumably overriding the straight layout along `direction`).
butterfly_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start=origin.reshape(3),
    direction=np.array([0.0, 0.0, 1.0]),
    normal=normal,
    base_length=total_length,
    base_radius=base_radius,
    density=density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
    position=positions,
)

butterfly_sim.append(butterfly_rod)
 77
 78# Add call backs
 79class VelocityCallBack(ea.CallBackBaseClass):
 80    """
 81    Call back function for continuum snake
 82    """
 83
 84    def __init__(self, step_skip: int, callback_params: dict):
 85        ea.CallBackBaseClass.__init__(self)
 86        self.every = step_skip
 87        self.callback_params = callback_params
 88
 89    def make_callback(self, system, time, current_step: int):
 90
 91        if current_step % self.every == 0:
 92
 93            self.callback_params["time"].append(time)
 94            # Collect x
 95            self.callback_params["position"].append(system.position_collection.copy())
 96            # Collect energies as well
 97            self.callback_params["te"].append(system.compute_translational_energy())
 98            self.callback_params["re"].append(system.compute_rotational_energy())
 99            self.callback_params["se"].append(system.compute_shear_energy())
100            self.callback_params["be"].append(system.compute_bending_energy())
101            return
102
103
recorded_history = ea.defaultdict(list)
# initially record history
# Seed the history with the initial configuration and energies at t = 0,
# using the same keys the callback appends to.
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
recorded_history["te"].append(butterfly_rod.compute_translational_energy())
recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
recorded_history["se"].append(butterfly_rod.compute_shear_energy())
recorded_history["be"].append(butterfly_rod.compute_bending_energy())

# Sample every 100 integration steps.
butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

# Time step proportional to the element length.
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)
126
if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Read the snapshots by index/slice instead of popping, so recorded_history
    # keeps every entry and the module-level `positions` array is not shadowed.
    snapshots = recorded_history["position"]
    # Initial configuration as a red dashed line (z horizontal, x vertical).
    first_position = snapshots[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    later_snapshots = snapshots[1:]
    n_positions = len(later_snapshots)
    for i, pos in enumerate(later_snapshots):
        # Earlier snapshots are more transparent: alpha rises from exp(-1)
        # towards 1 as i increases.
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    if later_snapshots:
        # Final configuration as a black dashed line.
        last_position = later_snapshots[-1]
        ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    # The context manager closes the file even if pickling raises.
    filename = "butterfly_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling

  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
  7    plot_helicalbuckling,
  8)
  9
 10
class HelicalBucklingSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
):
    """Simulator for the helical buckling case.

    Combines the system collection with constraint, damping and forcing
    support through its base classes; it defines no behavior of its own.
    """

    pass
 15
 16
helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
# Long rod laid along +z, starting at the origin.
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius ** 2
# Chosen so that density * base_area == 1 (unit mass per unit length).
density = 1.0 / (base_area)
# NOTE: `nu` is the damping constant (0.01 here, since density * base_area
# == 1), not a Poisson ratio.
nu = 0.01 / density / base_area
E = 1e6
# Passed to HelicalBucklingBC below as `slack` and `number_of_rotations`.
slack = 3
number_of_rotations = 27
# For shear modulus of 1e5, nu is 9 (convention G = E / (nu + 1) = 1e6 / 10).
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear matrix, repeated n_elem times along the last axis
# (presumably one per element).
shear_matrix = np.repeat(
    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
)
# Prescribed diagonal bend matrix: two equal entries and a third (presumably
# torsional) entry, repeated n_elem - 1 times along the last axis.
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

# Replace the stiffness matrices built by straight_rod with the ones above.
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
 65# add damping
 66dl = base_length / n_elem
 67dt = 1e-3 * dl
 68helicalbuckling_sim.dampen(shearable_rod).using(
 69    ea.AnalyticalLinearDamper,
 70    damping_constant=nu,
 71    time_step=dt,
 72)
 73
 74helicalbuckling_sim.constrain(shearable_rod).using(
 75    ea.HelicalBucklingBC,
 76    constrained_position_idx=(0, -1),
 77    constrained_director_idx=(0, -1),
 78    twisting_time=500,
 79    slack=slack,
 80    number_of_rotations=number_of_rotations,
 81)
 82
 83helicalbuckling_sim.finalize()
 84timestepper = ea.PositionVerlet()
 85shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
 86# timestepper = PEFRL()
 87
 88final_time = 10500.0
 89total_steps = int(final_time / dt)
 90print("Total steps", total_steps)
 91ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
 92
 93if PLOT_FIGURE:
 94    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
 95
 96if SAVE_RESULTS:
 97    import pickle
 98
 99    filename = "HelicalBuckling_data.dat"
100    file = open(filename, "wb")
101    pickle.dump(shearable_rod, file)
102    file.close()

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import os
  4import numpy as np
  5import elastica as ea
  6
  7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
  8    plot_snake_velocity,
  9    plot_video,
 10    compute_projected_velocity,
 11    plot_curvature,
 12)
 13
 14
class SnakeSimulator(
    ea.BaseSystemCollection,
    ea.Constraints,
    ea.Forcing,
    ea.Damping,
    ea.CallBacks,
    ea.Contact,
):
    """Simulator for the continuum snake case.

    Combines the system collection with constraint, forcing, damping, callback
    and contact support through its base classes; it defines no behavior of
    its own.
    """

    pass
 24
 25
 26def run_snake(
 27    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 28):
 29    # Initialize the simulation class
 30    snake_sim = SnakeSimulator()
 31
 32    # Simulation parameters
 33    period = 2
 34    final_time = (11.0 + 0.01) * period
 35
 36    # setting up test params
 37    n_elem = 50
 38    start = np.zeros((3,))
 39    direction = np.array([0.0, 0.0, 1.0])
 40    normal = np.array([0.0, 1.0, 0.0])
 41    base_length = 0.35
 42    base_radius = base_length * 0.011
 43    density = 1000
 44    E = 1e6
 45    poisson_ratio = 0.5
 46    shear_modulus = E / (poisson_ratio + 1.0)
 47
 48    shearable_rod = ea.CosseratRod.straight_rod(
 49        n_elem,
 50        start,
 51        direction,
 52        normal,
 53        base_length,
 54        base_radius,
 55        density,
 56        youngs_modulus=E,
 57        shear_modulus=shear_modulus,
 58    )
 59
 60    snake_sim.append(shearable_rod)
 61
 62    # Add gravitational forces
 63    gravitational_acc = -9.80665
 64    snake_sim.add_forcing_to(shearable_rod).using(
 65        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 66    )
 67
 68    # Add muscle torques
 69    wave_length = b_coeff[-1]
 70    snake_sim.add_forcing_to(shearable_rod).using(
 71        ea.MuscleTorques,
 72        base_length=base_length,
 73        b_coeff=b_coeff[:-1],
 74        period=period,
 75        wave_number=2.0 * np.pi / (wave_length),
 76        phase_shift=0.0,
 77        rest_lengths=shearable_rod.rest_lengths,
 78        ramp_up_time=period,
 79        direction=normal,
 80        with_spline=True,
 81    )
 82
 83    # Add friction forces
 84    ground_plane = ea.Plane(
 85        plane_origin=np.array([0.0, -base_radius, 0.0]), plane_normal=normal
 86    )
 87    snake_sim.append(ground_plane)
 88    slip_velocity_tol = 1e-8
 89    froude = 0.1
 90    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 91    kinetic_mu_array = np.array(
 92        [mu, 1.5 * mu, 2.0 * mu]
 93    )  # [forward, backward, sideways]
 94    static_mu_array = np.zeros(kinetic_mu_array.shape)
 95    snake_sim.detect_contact_between(shearable_rod, ground_plane).using(
 96        ea.RodPlaneContactWithAnisotropicFriction,
 97        k=1.0,
 98        nu=1e-6,
 99        slip_velocity_tol=slip_velocity_tol,
100        static_mu_array=static_mu_array,
101        kinetic_mu_array=kinetic_mu_array,
102    )
103
104    # add damping
105    damping_constant = 2e-3
106    time_step = 1e-4
107    snake_sim.dampen(shearable_rod).using(
108        ea.AnalyticalLinearDamper,
109        damping_constant=damping_constant,
110        time_step=time_step,
111    )
112
113    total_steps = int(final_time / time_step)
114    rendering_fps = 60
115    step_skip = int(1.0 / (rendering_fps * time_step))
116
117    # Add call backs
118    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
119        """
120        Call back function for continuum snake
121        """
122
123        def __init__(self, step_skip: int, callback_params: dict):
124            ea.CallBackBaseClass.__init__(self)
125            self.every = step_skip
126            self.callback_params = callback_params
127
128        def make_callback(self, system, time, current_step: int):
129
130            if current_step % self.every == 0:
131
132                self.callback_params["time"].append(time)
133                self.callback_params["step"].append(current_step)
134                self.callback_params["position"].append(
135                    system.position_collection.copy()
136                )
137                self.callback_params["velocity"].append(
138                    system.velocity_collection.copy()
139                )
140                self.callback_params["avg_velocity"].append(
141                    system.compute_velocity_center_of_mass()
142                )
143
144                self.callback_params["center_of_mass"].append(
145                    system.compute_position_center_of_mass()
146                )
147                self.callback_params["curvature"].append(system.kappa.copy())
148
149                return
150
151    pp_list = ea.defaultdict(list)
152    snake_sim.collect_diagnostics(shearable_rod).using(
153        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
154    )
155
156    snake_sim.finalize()
157
158    timestepper = ea.PositionVerlet()
159    ea.integrate(timestepper, snake_sim, final_time, total_steps)
160
161    if PLOT_FIGURE:
162        filename_plot = "continuum_snake_velocity.png"
163        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
164        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
165
166        if SAVE_VIDEO:
167            filename_video = "continuum_snake.mp4"
168            plot_video(
169                pp_list,
170                video_name=filename_video,
171                fps=rendering_fps,
172                xlim=(0, 4),
173                ylim=(-1, 1),
174            )
175
176    if SAVE_RESULTS:
177        import pickle
178
179        filename = "continuum_snake.dat"
180        file = open(filename, "wb")
181        pickle.dump(pp_list, file)
182        file.close()
183
184    # Compute the average forward velocity. These will be used for optimization.
185    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
186
187    return avg_forward, avg_lateral, pp_list
188
189
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: the negated average forward velocity.

            cma minimises its objective, so negating the velocity makes the
            optimiser search for the fastest forward gait.
            """
            avg_forward, _, _ = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are optimizing
        # for and third input is standard deviation you initially set.
        # The 7 parameters are 6 spline coefficients followed by the wave length
        # (run_snake splits them as b_coeff[:-1] and b_coeff[-1]).
        # NOTE(review): the initial guess puts the wave length at 0, so samples
        # near it give very large wave numbers (2*pi / wave_length).
        # cma.fmin returns a result sequence, not the solution itself; its first
        # entry is the best solution found, which is what gets saved below.
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file; the branch below reads this
        # file back on the next run.
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        # Reuse previously optimized coefficients when available; otherwise use a
        # hand-picked set with a unit wave length appended as the last entry.
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        avg_forward, avg_lateral, pp_list = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)