Example Cases

Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as an initial template to build your own case study.

Axial Stretching

  1""" Axial stretching test-case
  2
  3    Assume we have a rod lying aligned in the x-direction, with high internal
  4    damping.
  5
  6    We fix one end (say, the left end) of the rod to a wall. On the right
  7    end we apply a force directed axially pulling the rods tip. Linear
  8    theory (assuming small displacements) predict that the net displacement
  9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10    usual meaning (the rod is just a linear spring). We compare our results
 11    with the above result.
 12
 13    We can "improve" the theory by having a better estimate for the rod's
 14    spring constant by assuming that it equilibriates under the new position,
 15    with
 16    Δx = F * (L + Δx)/ (A * E)
 17    which results in Δx = (F*l)/(A*E - F). Our rod reaches equilibrium wrt to
 18    this position.
 19
 20    Note that if the damping is not high, the rod oscillates about the eventual
 21    resting position (and this agrees with the theoretical predictions without
 22    any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24    isort:skip_file
 25"""
 26import numpy as np
 27from matplotlib import pyplot as plt
 28
 29from elastica import *
 30
 31
 32class StretchingBeamSimulator(
 33    BaseSystemCollection, Constraints, Forcing, Damping, CallBacks
 34):
 35    pass
 36
 37
 38stretch_sim = StretchingBeamSimulator()
 39final_time = 200.0
 40
 41# Options
 42PLOT_FIGURE = True
 43SAVE_FIGURE = False
 44SAVE_RESULTS = False
 45
 46# setting up test params
 47n_elem = 19
 48start = np.zeros((3,))
 49direction = np.array([1.0, 0.0, 0.0])
 50normal = np.array([0.0, 1.0, 0.0])
 51base_length = 1.0
 52base_radius = 0.025
 53base_area = np.pi * base_radius ** 2
 54density = 1000
 55youngs_modulus = 1e4
 56# For shear modulus of 1e4, nu is 99!
 57poisson_ratio = 0.5
 58shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 59
 60stretchable_rod = CosseratRod.straight_rod(
 61    n_elem,
 62    start,
 63    direction,
 64    normal,
 65    base_length,
 66    base_radius,
 67    density,
 68    0.0,  # internal damping constant, deprecated in v0.3.0
 69    youngs_modulus,
 70    shear_modulus=shear_modulus,
 71)
 72
 73stretch_sim.append(stretchable_rod)
 74stretch_sim.constrain(stretchable_rod).using(
 75    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 76)
 77
 78end_force_x = 1.0
 79end_force = np.array([end_force_x, 0.0, 0.0])
 80stretch_sim.add_forcing_to(stretchable_rod).using(
 81    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
 82)
 83
 84# add damping
 85dl = base_length / n_elem
 86# old damping model (deprecated in v0.3.0) values
 87# dt = 0.01 * dl
 88# damping_constant = 1.0
 89dt = 0.1 * dl
 90damping_constant = 0.1
 91stretch_sim.dampen(stretchable_rod).using(
 92    AnalyticalLinearDamper,
 93    damping_constant=damping_constant,
 94    time_step=dt,
 95)
 96
 97# Add call backs
 98class AxialStretchingCallBack(CallBackBaseClass):
 99    """
100    Tracks the velocity norms of the rod
101    """
102
103    def __init__(self, step_skip: int, callback_params: dict):
104        CallBackBaseClass.__init__(self)
105        self.every = step_skip
106        self.callback_params = callback_params
107
108    def make_callback(self, system, time, current_step: int):
109
110        if current_step % self.every == 0:
111
112            self.callback_params["time"].append(time)
113            # Collect only x
114            self.callback_params["position"].append(
115                system.position_collection[0, -1].copy()
116            )
117            self.callback_params["velocity_norms"].append(
118                np.linalg.norm(system.velocity_collection.copy())
119            )
120            return
121
122
# Attach the diagnostics callback, sampling the rod every 200 steps.
recorded_history = defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack,
    step_skip=200,
    callback_params=recorded_history,
)

# Close the system collection, then integrate over total_steps steps.
stretch_sim.finalize()
timestepper = PositionVerlet()  # PEFRL() is an alternative stepper
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, stretch_sim, final_time, total_steps)
135
if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="Simulated tip position",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="Linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="Modified-length theory",
    )
    # Label the curves so the two theoretical estimates can be told apart.
    ax.set_xlabel("Time")
    ax.set_ylabel("Tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()
154
if SAVE_RESULTS:
    import pickle

    # Persist the final rod state; the context manager closes the file even
    # if pickling fails.
    filename = "axial_stretching_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        """Transpose a (2, N) stack into N rows of (time, velocity norm)."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5from elastica import *
  6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
  7
  8
  9class TimoshenkoBeamSimulator(
 10    BaseSystemCollection, Constraints, Forcing, CallBacks, Damping
 11):
 12    pass
 13
 14
 15timoshenko_sim = TimoshenkoBeamSimulator()
 16final_time = 5000.0
 17
 18# Options
 19PLOT_FIGURE = True
 20SAVE_FIGURE = True
 21SAVE_RESULTS = False
 22ADD_UNSHEARABLE_ROD = False
 23
 24# setting up test params
 25n_elem = 100
 26start = np.zeros((3,))
 27direction = np.array([0.0, 0.0, 1.0])
 28normal = np.array([0.0, 1.0, 0.0])
 29base_length = 3.0
 30base_radius = 0.25
 31base_area = np.pi * base_radius ** 2
 32density = 5000
 33nu = 0.1 / 7 / density / base_area
 34E = 1e6
 35# For shear modulus of 1e4, nu is 99!
 36poisson_ratio = 99
 37shear_modulus = E / (poisson_ratio + 1.0)
 38
 39shearable_rod = CosseratRod.straight_rod(
 40    n_elem,
 41    start,
 42    direction,
 43    normal,
 44    base_length,
 45    base_radius,
 46    density,
 47    0.0,  # internal damping constant, deprecated in v0.3.0
 48    E,
 49    shear_modulus=shear_modulus,
 50)
 51
 52timoshenko_sim.append(shearable_rod)
 53# add damping
 54dl = base_length / n_elem
 55dt = 0.07 * dl
 56timoshenko_sim.dampen(shearable_rod).using(
 57    AnalyticalLinearDamper,
 58    damping_constant=nu,
 59    time_step=dt,
 60)
 61
 62timoshenko_sim.constrain(shearable_rod).using(
 63    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 64)
 65
 66end_force = np.array([-15.0, 0.0, 0.0])
 67timoshenko_sim.add_forcing_to(shearable_rod).using(
 68    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 69)
 70
 71
 72if ADD_UNSHEARABLE_ROD:
 73    # Start into the plane
 74    unshearable_start = np.array([0.0, -1.0, 0.0])
 75    shear_modulus = E / (-0.7 + 1.0)
 76    unshearable_rod = CosseratRod.straight_rod(
 77        n_elem,
 78        unshearable_start,
 79        direction,
 80        normal,
 81        base_length,
 82        base_radius,
 83        density,
 84        0.0,  # internal damping constant, deprecated in v0.3.0
 85        E,
 86        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
 87        shear_modulus=shear_modulus,
 88    )
 89
 90    timoshenko_sim.append(unshearable_rod)
 91
 92    # add damping
 93    timoshenko_sim.dampen(unshearable_rod).using(
 94        AnalyticalLinearDamper,
 95        damping_constant=nu,
 96        time_step=dt,
 97    )
 98    timoshenko_sim.constrain(unshearable_rod).using(
 99        OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
100    )
101    timoshenko_sim.add_forcing_to(unshearable_rod).using(
102        EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
103    )
104
# Add call backs
class VelocityCallBack(CallBackBaseClass):
    """
    Records, every `step_skip` steps, the simulation time and the norm of the
    rod's velocity field into callback_params["time"] and
    callback_params["velocity_norms"].
    """

    def __init__(self, step_skip: int, callback_params: dict):
        CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):

        if current_step % self.every == 0:

            self.callback_params["time"].append(time)
            # Collect the norm of the whole velocity field; norm() does not
            # modify its input, so no defensive copy is needed.
            self.callback_params["velocity_norms"].append(
                np.linalg.norm(system.velocity_collection)
            )
            return
126
127
# Sample the velocity norm of the shearable rod every 500 steps.
recorded_history = defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack,
    step_skip=500,
    callback_params=recorded_history,
)

# Close the system collection, then integrate over total_steps steps.
timoshenko_sim.finalize()
timestepper = PositionVerlet()  # PEFRL() is an alternative stepper
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, timoshenko_sim, final_time, total_steps)
140
if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    # Persist the final state of the shearable rod; the context manager
    # closes the file even if pickling fails.
    filename = "Timoshenko_beam_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        """Transpose a (2, N) stack into N rows of (time, velocity norm)."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )

Butterfly

  1import numpy as np
  2from matplotlib import pyplot as plt
  3from matplotlib.colors import to_rgb
  4
  5
  6from elastica import *
  7from elastica.utils import MaxDimension
  8
  9
 10class ButterflySimulator(BaseSystemCollection, CallBacks):
 11    pass
 12
 13
 14butterfly_sim = ButterflySimulator()
 15final_time = 40.0
 16
 17# Options
 18PLOT_FIGURE = True
 19SAVE_FIGURE = True
 20SAVE_RESULTS = True
 21ADD_UNSHEARABLE_ROD = False
 22
 23# setting up test params
 24# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
 25n_elem = 4  # Change based on requirements, but be careful
 26n_elem += n_elem % 2
 27half_n_elem = n_elem // 2
 28
 29origin = np.zeros((3, 1))
 30angle_of_inclination = np.deg2rad(45.0)
 31
 32# in-plane
 33horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
 34vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
 35
 36# out-of-plane
 37normal = np.array([0.0, 1.0, 0.0])
 38
 39total_length = 3.0
 40base_radius = 0.25
 41base_area = np.pi * base_radius ** 2
 42density = 5000
 43youngs_modulus = 1e4
 44poisson_ratio = 0.5
 45shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 46
 47positions = np.empty((MaxDimension.value(), n_elem + 1))
 48dl = total_length / n_elem
 49
 50# First half of positions stem from slope angle_of_inclination
 51first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
 52positions[..., : half_n_elem + 1] = origin + dl * first_half * (
 53    np.cos(angle_of_inclination) * horizontal_direction
 54    + np.sin(angle_of_inclination) * vertical_direction
 55)
 56positions[..., half_n_elem:] = positions[
 57    ..., half_n_elem : half_n_elem + 1
 58] + dl * first_half * (
 59    np.cos(angle_of_inclination) * horizontal_direction
 60    - np.sin(angle_of_inclination) * vertical_direction
 61)
 62
 63butterfly_rod = CosseratRod.straight_rod(
 64    n_elem,
 65    start=origin.reshape(3),
 66    direction=np.array([0.0, 0.0, 1.0]),
 67    normal=normal,
 68    base_length=total_length,
 69    base_radius=base_radius,
 70    density=density,
 71    nu=0.0,  # internal damping constant, deprecated in v0.3.0
 72    youngs_modulus=youngs_modulus,
 73    shear_modulus=shear_modulus,
 74    position=positions,
 75)
 76
 77butterfly_sim.append(butterfly_rod)
 78
 79# Add call backs
 80class VelocityCallBack(CallBackBaseClass):
 81    """
 82    Call back function for continuum snake
 83    """
 84
 85    def __init__(self, step_skip: int, callback_params: dict):
 86        CallBackBaseClass.__init__(self)
 87        self.every = step_skip
 88        self.callback_params = callback_params
 89
 90    def make_callback(self, system, time, current_step: int):
 91
 92        if current_step % self.every == 0:
 93
 94            self.callback_params["time"].append(time)
 95            # Collect x
 96            self.callback_params["position"].append(system.position_collection.copy())
 97            # Collect energies as well
 98            self.callback_params["te"].append(system.compute_translational_energy())
 99            self.callback_params["re"].append(system.compute_rotational_energy())
100            self.callback_params["se"].append(system.compute_shear_energy())
101            self.callback_params["be"].append(system.compute_bending_energy())
102            return
103
104
# Seed the history with the initial state, then let the callback append a
# sample every 100 steps.
recorded_history = defaultdict(
    list,
    {
        "time": [0.0],
        "position": [butterfly_rod.position_collection.copy()],
        "te": [butterfly_rod.compute_translational_energy()],
        "re": [butterfly_rod.compute_rotational_energy()],
        "se": [butterfly_rod.compute_shear_energy()],
        "be": [butterfly_rod.compute_bending_energy()],
    },
)
butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack,
    step_skip=100,
    callback_params=recorded_history,
)

butterfly_sim.finalize()
timestepper = PositionVerlet()  # PEFRL() is an alternative stepper

# Time step proportional to the element length.
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, butterfly_sim, final_time, total_steps)
127
if PLOT_FIGURE:
    # Plot the histories of the rod shape (z horizontal, x vertical): the
    # initial shape in red dashed, intermediate shapes in blue fading in with
    # time, and the final shape in black dashed.
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Index into the history instead of popping from it, so that
    # recorded_history["position"] stays aligned with recorded_history["time"]
    # and the final shape is not drawn twice.
    position_history = recorded_history["position"]
    first_position = position_history[0]
    last_position = position_history[-1]
    intermediate_positions = position_history[1:-1]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    n_positions = len(intermediate_positions)
    for i, pos in enumerate(intermediate_positions):
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final position is drawn once, on top of the intermediate shapes
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.set_xlabel("Time")
    energy_ax.set_ylabel("Energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()
169
if SAVE_RESULTS:
    import pickle

    # Persist the final rod state; the context manager closes the file even
    # if pickling fails.
    filename = "butterfly_data.dat"
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling

  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
  3
  4import numpy as np
  5from elastica import *
  6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
  7    plot_helicalbuckling,
  8)
  9
 10
 11class HelicalBucklingSimulator(BaseSystemCollection, Constraints, Damping, Forcing):
 12    pass
 13
 14
 15helicalbuckling_sim = HelicalBucklingSimulator()
 16
 17# Options
 18PLOT_FIGURE = True
 19SAVE_FIGURE = True
 20SAVE_RESULTS = False
 21
 22# setting up test params
 23n_elem = 100
 24start = np.zeros((3,))
 25direction = np.array([0.0, 0.0, 1.0])
 26normal = np.array([0.0, 1.0, 0.0])
 27base_length = 100.0
 28base_radius = 0.35
 29base_area = np.pi * base_radius ** 2
 30density = 1.0 / (base_area)
 31nu = 0.01 / density / base_area
 32E = 1e6
 33slack = 3
 34number_of_rotations = 27
 35# For shear modulus of 1e5, nu is 99!
 36poisson_ratio = 9
 37shear_modulus = E / (poisson_ratio + 1.0)
 38shear_matrix = np.repeat(
 39    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
 40)
 41temp_bend_matrix = np.zeros((3, 3))
 42np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
 43bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)
 44
 45shearable_rod = CosseratRod.straight_rod(
 46    n_elem,
 47    start,
 48    direction,
 49    normal,
 50    base_length,
 51    base_radius,
 52    density,
 53    0.0,  # internal damping constant, deprecated in v0.3.0
 54    E,
 55    shear_modulus=shear_modulus,
 56)
 57# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below
 58
 59shearable_rod.shear_matrix = shear_matrix
 60shearable_rod.bend_matrix = bend_matrix
 61
 62
 63helicalbuckling_sim.append(shearable_rod)
 64# add damping
 65dl = base_length / n_elem
 66dt = 1e-3 * dl
 67helicalbuckling_sim.dampen(shearable_rod).using(
 68    AnalyticalLinearDamper,
 69    damping_constant=nu,
 70    time_step=dt,
 71)
 72
 73helicalbuckling_sim.constrain(shearable_rod).using(
 74    HelicalBucklingBC,
 75    constrained_position_idx=(0, -1),
 76    constrained_director_idx=(0, -1),
 77    twisting_time=500,
 78    slack=slack,
 79    number_of_rotations=number_of_rotations,
 80)
 81
 82helicalbuckling_sim.finalize()
 83timestepper = PositionVerlet()
 84shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
 85# timestepper = PEFRL()
 86
 87final_time = 10500.0
 88total_steps = int(final_time / dt)
 89print("Total steps", total_steps)
 90integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
 91
 92if PLOT_FIGURE:
 93    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
 94
 95if SAVE_RESULTS:
 96    import pickle
 97
 98    filename = "HelicalBuckling_data.dat"
 99    file = open(filename, "wb")
100    pickle.dump(shearable_rod, file)
101    file.close()

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import os
  4import numpy as np
  5from elastica import *
  6
  7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
  8    plot_snake_velocity,
  9    plot_video,
 10    compute_projected_velocity,
 11    plot_curvature,
 12)
 13
 14
 15class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, Damping, CallBacks):
 16    pass
 17
 18
 19def run_snake(
 20    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 21):
 22    # Initialize the simulation class
 23    snake_sim = SnakeSimulator()
 24
 25    # Simulation parameters
 26    period = 2
 27    final_time = (11.0 + 0.01) * period
 28
 29    # setting up test params
 30    n_elem = 50
 31    start = np.zeros((3,))
 32    direction = np.array([0.0, 0.0, 1.0])
 33    normal = np.array([0.0, 1.0, 0.0])
 34    base_length = 0.35
 35    base_radius = base_length * 0.011
 36    density = 1000
 37    E = 1e6
 38    poisson_ratio = 0.5
 39    shear_modulus = E / (poisson_ratio + 1.0)
 40
 41    shearable_rod = CosseratRod.straight_rod(
 42        n_elem,
 43        start,
 44        direction,
 45        normal,
 46        base_length,
 47        base_radius,
 48        density,
 49        0.0,  # internal damping constant, deprecated in v0.3.0
 50        E,
 51        shear_modulus=shear_modulus,
 52    )
 53
 54    snake_sim.append(shearable_rod)
 55
 56    # Add gravitational forces
 57    gravitational_acc = -9.80665
 58    snake_sim.add_forcing_to(shearable_rod).using(
 59        GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 60    )
 61
 62    # Add muscle torques
 63    wave_length = b_coeff[-1]
 64    snake_sim.add_forcing_to(shearable_rod).using(
 65        MuscleTorques,
 66        base_length=base_length,
 67        b_coeff=b_coeff[:-1],
 68        period=period,
 69        wave_number=2.0 * np.pi / (wave_length),
 70        phase_shift=0.0,
 71        rest_lengths=shearable_rod.rest_lengths,
 72        ramp_up_time=period,
 73        direction=normal,
 74        with_spline=True,
 75    )
 76
 77    # Add friction forces
 78    origin_plane = np.array([0.0, -base_radius, 0.0])
 79    normal_plane = normal
 80    slip_velocity_tol = 1e-8
 81    froude = 0.1
 82    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 83    kinetic_mu_array = np.array(
 84        [mu, 1.5 * mu, 2.0 * mu]
 85    )  # [forward, backward, sideways]
 86    static_mu_array = np.zeros(kinetic_mu_array.shape)
 87    snake_sim.add_forcing_to(shearable_rod).using(
 88        AnisotropicFrictionalPlane,
 89        k=1.0,
 90        nu=1e-6,
 91        plane_origin=origin_plane,
 92        plane_normal=normal_plane,
 93        slip_velocity_tol=slip_velocity_tol,
 94        static_mu_array=static_mu_array,
 95        kinetic_mu_array=kinetic_mu_array,
 96    )
 97
 98    # add damping
 99    # old damping model (deprecated in v0.3.0) values
100    # damping_constant = 2e-3
101    # time_step = 8e-6
102    damping_constant = 2e-3
103    time_step = 1e-4
104    snake_sim.dampen(shearable_rod).using(
105        AnalyticalLinearDamper,
106        damping_constant=damping_constant,
107        time_step=time_step,
108    )
109
110    total_steps = int(final_time / time_step)
111    rendering_fps = 60
112    step_skip = int(1.0 / (rendering_fps * time_step))
113
114    # Add call backs
115    class ContinuumSnakeCallBack(CallBackBaseClass):
116        """
117        Call back function for continuum snake
118        """
119
120        def __init__(self, step_skip: int, callback_params: dict):
121            CallBackBaseClass.__init__(self)
122            self.every = step_skip
123            self.callback_params = callback_params
124
125        def make_callback(self, system, time, current_step: int):
126
127            if current_step % self.every == 0:
128
129                self.callback_params["time"].append(time)
130                self.callback_params["step"].append(current_step)
131                self.callback_params["position"].append(
132                    system.position_collection.copy()
133                )
134                self.callback_params["velocity"].append(
135                    system.velocity_collection.copy()
136                )
137                self.callback_params["avg_velocity"].append(
138                    system.compute_velocity_center_of_mass()
139                )
140
141                self.callback_params["center_of_mass"].append(
142                    system.compute_position_center_of_mass()
143                )
144                self.callback_params["curvature"].append(system.kappa.copy())
145
146                return
147
148    pp_list = defaultdict(list)
149    snake_sim.collect_diagnostics(shearable_rod).using(
150        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
151    )
152
153    snake_sim.finalize()
154
155    timestepper = PositionVerlet()
156    integrate(timestepper, snake_sim, final_time, total_steps)
157
158    if PLOT_FIGURE:
159        filename_plot = "continuum_snake_velocity.png"
160        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
161        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
162
163        if SAVE_VIDEO:
164            filename_video = "continuum_snake.mp4"
165            plot_video(
166                pp_list,
167                video_name=filename_video,
168                fps=rendering_fps,
169                xlim=(0, 4),
170                ylim=(-1, 1),
171            )
172
173    if SAVE_RESULTS:
174        import pickle
175
176        filename = "continuum_snake.dat"
177        file = open(filename, "wb")
178        pickle.dump(pp_list, file)
179        file.close()
180
181    # Compute the average forward velocity. These will be used for optimization.
182    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
183
184    return avg_forward, avg_lateral, pp_list
185
186
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: the negated average forward velocity."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin the first input is
        # the function to be minimized, the second input is the initial guess
        # for the coefficients being optimized (6 spline coefficients plus the
        # wave length), and the third input is the initial standard deviation.
        # cma.fmin returns a result sequence whose first entry is the best
        # solution found; only that vector is kept (saving the whole result
        # with np.savetxt would fail).
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)