Example Cases

Example cases demonstrate physical problems with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as templates for building your own case studies.

Axial Stretching

  1""" Axial stretching test-case
  2
  3    Assume we have a rod lying aligned in the x-direction, with high internal
  4    damping.
  5
  6    We fix one end (say, the left end) of the rod to a wall. On the right
  7    end we apply a force directed axially pulling the rods tip. Linear
  8    theory (assuming small displacements) predict that the net displacement
  9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10    usual meaning (the rod is just a linear spring). We compare our results
 11    with the above result.
 12
 13    We can "improve" the theory by having a better estimate for the rod's
 14    spring constant by assuming that it equilibriates under the new position,
 15    with
 16    Δx = F * (L + Δx)/ (A * E)
 17    which results in Δx = (F*l)/(A*E - F). Our rod reaches equilibrium wrt to
 18    this position.
 19
 20    Note that if the damping is not high, the rod oscillates about the eventual
 21    resting position (and this agrees with the theoretical predictions without
 22    any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24    isort:skip_file
 25"""
 26# FIXME without appending sys.path make it more generic
 27import sys
 28
 29sys.path.append("../../")  # isort:skip
 30
 31# from collections import defaultdict
 32
 33import numpy as np
 34from matplotlib import pyplot as plt
 35
 36from elastica import *
 37
 38
 39class StretchingBeamSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
 40    pass
 41
 42
 43stretch_sim = StretchingBeamSimulator()
 44final_time = 20.0
 45
 46# Options
 47PLOT_FIGURE = True
 48SAVE_FIGURE = False
 49SAVE_RESULTS = False
 50
 51# setting up test params
 52n_elem = 19
 53start = np.zeros((3,))
 54direction = np.array([1.0, 0.0, 0.0])
 55normal = np.array([0.0, 1.0, 0.0])
 56base_length = 1.0
 57base_radius = 0.025
 58base_area = np.pi * base_radius ** 2
 59density = 1000
 60nu = 2.0
 61youngs_modulus = 1e4
 62# For shear modulus of 1e4, nu is 99!
 63poisson_ratio = 0.5
 64shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 65
 66stretchable_rod = CosseratRod.straight_rod(
 67    n_elem,
 68    start,
 69    direction,
 70    normal,
 71    base_length,
 72    base_radius,
 73    density,
 74    nu,
 75    youngs_modulus,
 76    shear_modulus=shear_modulus,
 77)
 78
 79stretch_sim.append(stretchable_rod)
 80stretch_sim.constrain(stretchable_rod).using(
 81    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 82)
 83
 84end_force_x = 1.0
 85end_force = np.array([end_force_x, 0.0, 0.0])
 86stretch_sim.add_forcing_to(stretchable_rod).using(
 87    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
 88)
 89
 90# Add call backs
 91class AxialStretchingCallBack(CallBackBaseClass):
 92    """
 93    Call back function for continuum snake
 94    """
 95
 96    def __init__(self, step_skip: int, callback_params: dict):
 97        CallBackBaseClass.__init__(self)
 98        self.every = step_skip
 99        self.callback_params = callback_params
100
101    def make_callback(self, system, time, current_step: int):
102
103        if current_step % self.every == 0:
104
105            self.callback_params["time"].append(time)
106            # Collect only x
107            self.callback_params["position"].append(
108                system.position_collection[0, -1].copy()
109            )
110            return
111
112
recorded_history = defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dl = base_length / n_elem
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="Simulation",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="Linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="Improved theory",
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("Tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # Context manager guarantees the file is closed even if pickling fails.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to 
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import sys
  6
  7# FIXME without appending sys.path make it more generic
  8sys.path.append("../../")
  9from elastica import *
 10from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
 11
 12
 13class TimoshenkoBeamSimulator(BaseSystemCollection, Constraints, Forcing):
 14    pass
 15
 16
 17timoshenko_sim = TimoshenkoBeamSimulator()
 18final_time = 5000
 19
 20# Options
 21PLOT_FIGURE = True
 22SAVE_FIGURE = False
 23SAVE_RESULTS = False
 24ADD_UNSHEARABLE_ROD = True
 25
 26# setting up test params
 27n_elem = 100
 28start = np.zeros((3,))
 29direction = np.array([0.0, 0.0, 1.0])
 30normal = np.array([0.0, 1.0, 0.0])
 31base_length = 3.0
 32base_radius = 0.25
 33base_area = np.pi * base_radius ** 2
 34density = 5000
 35nu = 0.1
 36E = 1e6
 37# For shear modulus of 1e4, nu is 99!
 38poisson_ratio = 99
 39shear_modulus = E / (poisson_ratio + 1.0)
 40
 41shearable_rod = CosseratRod.straight_rod(
 42    n_elem,
 43    start,
 44    direction,
 45    normal,
 46    base_length,
 47    base_radius,
 48    density,
 49    nu,
 50    E,
 51    shear_modulus=shear_modulus,
 52)
 53
 54timoshenko_sim.append(shearable_rod)
 55timoshenko_sim.constrain(shearable_rod).using(
 56    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 57)
 58
 59end_force = np.array([-15.0, 0.0, 0.0])
 60timoshenko_sim.add_forcing_to(shearable_rod).using(
 61    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 62)
 63
 64
 65if ADD_UNSHEARABLE_ROD:
 66    # Start into the plane
 67    unshearable_start = np.array([0.0, -1.0, 0.0])
 68    shear_modulus = E / (-0.7 + 1.0)
 69    unshearable_rod = CosseratRod.straight_rod(
 70        n_elem,
 71        unshearable_start,
 72        direction,
 73        normal,
 74        base_length,
 75        base_radius,
 76        density,
 77        nu,
 78        E,
 79        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
 80        shear_modulus=shear_modulus,
 81    )
 82
 83    timoshenko_sim.append(unshearable_rod)
 84    timoshenko_sim.constrain(unshearable_rod).using(
 85        OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 86    )
 87    timoshenko_sim.add_forcing_to(unshearable_rod).using(
 88        EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 89    )
 90
 91timoshenko_sim.finalize()
 92timestepper = PositionVerlet()
 93# timestepper = PEFRL()
 94
 95dl = base_length / n_elem
 96dt = 0.01 * dl
 97total_steps = int(final_time / dt)
 98print("Total steps", total_steps)
 99integrate(timestepper, timoshenko_sim, final_time, total_steps)
100
101if PLOT_FIGURE:
102    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)
103
104if SAVE_RESULTS:
105    import pickle
106
107    filename = "Timoshenko_beam_data.dat"
108    file = open(filename, "wb")
109    pickle.dump(shearable_rod, file)
110    file.close()

Butterfly

  1# FIXME without appending sys.path make it more generic
  2import sys
  3
  4sys.path.append("../")
  5sys.path.append("../../")
  6
  7# from collections import defaultdict
  8import numpy as np
  9from matplotlib import pyplot as plt
 10from matplotlib.colors import to_rgb
 11
 12
 13from elastica import *
 14from elastica.utils import MaxDimension
 15
 16
 17class ButterflySimulator(BaseSystemCollection, CallBacks):
 18    pass
 19
 20
 21butterfly_sim = ButterflySimulator()
 22final_time = 40.0
 23
 24# Options
 25PLOT_FIGURE = True
 26SAVE_FIGURE = True
 27SAVE_RESULTS = True
 28ADD_UNSHEARABLE_ROD = False
 29
 30# setting up test params
 31# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
 32n_elem = 4  # Change based on requirements, but be careful
 33n_elem += n_elem % 2
 34half_n_elem = n_elem // 2
 35
 36origin = np.zeros((3, 1))
 37angle_of_inclination = np.deg2rad(45.0)
 38
 39# in-plane
 40horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
 41vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
 42
 43# out-of-plane
 44normal = np.array([0.0, 1.0, 0.0])
 45
 46total_length = 3.0
 47base_radius = 0.25
 48base_area = np.pi * base_radius ** 2
 49density = 5000
 50nu = 0.0
 51youngs_modulus = 1e4
 52poisson_ratio = 0.5
 53shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 54
 55positions = np.empty((MaxDimension.value(), n_elem + 1))
 56dl = total_length / n_elem
 57
 58# First half of positions stem from slope angle_of_inclination
 59first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
 60positions[..., : half_n_elem + 1] = origin + dl * first_half * (
 61    np.cos(angle_of_inclination) * horizontal_direction
 62    + np.sin(angle_of_inclination) * vertical_direction
 63)
 64positions[..., half_n_elem:] = positions[
 65    ..., half_n_elem : half_n_elem + 1
 66] + dl * first_half * (
 67    np.cos(angle_of_inclination) * horizontal_direction
 68    - np.sin(angle_of_inclination) * vertical_direction
 69)
 70
 71butterfly_rod = CosseratRod.straight_rod(
 72    n_elem,
 73    start=origin.reshape(3),
 74    direction=np.array([0.0, 0.0, 1.0]),
 75    normal=normal,
 76    base_length=total_length,
 77    base_radius=base_radius,
 78    density=density,
 79    nu=nu,
 80    youngs_modulus=youngs_modulus,
 81    shear_modulus=shear_modulus,
 82    position=positions,
 83)
 84
 85butterfly_sim.append(butterfly_rod)
 86
 87# Add call backs
 88class VelocityCallBack(CallBackBaseClass):
 89    """
 90    Call back function for continuum snake
 91    """
 92
 93    def __init__(self, step_skip: int, callback_params: dict):
 94        CallBackBaseClass.__init__(self)
 95        self.every = step_skip
 96        self.callback_params = callback_params
 97
 98    def make_callback(self, system, time, current_step: int):
 99
100        if current_step % self.every == 0:
101
102            self.callback_params["time"].append(time)
103            # Collect x
104            self.callback_params["position"].append(system.position_collection.copy())
105            # Collect energies as well
106            self.callback_params["te"].append(system.compute_translational_energy())
107            self.callback_params["re"].append(system.compute_rotational_energy())
108            self.callback_params["se"].append(system.compute_shear_energy())
109            self.callback_params["be"].append(system.compute_bending_energy())
110            return
111
112
recorded_history = defaultdict(list)
# initially record history, so the plots start from the t = 0 configuration
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
recorded_history["te"].append(butterfly_rod.compute_translational_energy())
recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
recorded_history["se"].append(butterfly_rod.compute_shear_energy())
recorded_history["be"].append(butterfly_rod.compute_bending_energy())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, butterfly_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Unpack without mutating the recorded history (and without shadowing the
    # module-level `positions` array). The first entry is the initial state.
    first_position, *later_positions = recorded_history["position"]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        # Fade in: later snapshots are drawn more opaque
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final position is also highlighted separately; guard against a history
    # holding only the initial snapshot (previously an IndexError)
    if later_positions:
        last_position = later_positions[-1]
        ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # Context manager guarantees the file is closed even if pickling fails.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling

 1__doc__ = """Helical buckling validation case, for detailed explanation refer to 
 2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
 3
 4import numpy as np
 5import sys
 6
 7# FIXME without appending sys.path make it more generic
 8sys.path.append("../../")
 9from elastica import *
10from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
11    plot_helicalbuckling,
12)
13
14
class HelicalBucklingSimulator(BaseSystemCollection, Constraints, Forcing):
    """Simulation container composed from the BaseSystemCollection,
    Constraints and Forcing mixins; it adds no behavior of its own."""
17
18
helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius ** 2
density = 1.0 / (base_area)  # chosen so that density * base_area == 1
nu = 0.01
E = 1e6
slack = 3
number_of_rotations = 27
# With G = E / (1 + poisson_ratio), a Poisson ratio of 9 gives G = 1e5.
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear stiffness G * I, one 3x3 matrix per element.
shear_matrix = np.repeat(
    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
)
# Prescribed bend/twist stiffness diag(1.345, 1.345, 0.789), repeated
# n_elem - 1 times (presumably one per interior node).
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below
# Overwrite the stiffness matrices computed by straight_rod with the ones above.
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
# Both ends are held; HelicalBucklingBC twists them over `twisting_time`
# using `slack` and `number_of_rotations` (see its implementation for details).
helicalbuckling_sim.constrain(shearable_rod).using(
    HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)

helicalbuckling_sim.finalize()
timestepper = PositionVerlet()
# Tiny y-velocity kick at the middle node, presumably to break the symmetry of
# the perfectly straight rod so that buckling can develop.
shearable_rod.velocity_collection[..., n_elem // 2] += np.array([0, 1e-6, 0.0])
# timestepper = PEFRL()

final_time = 10500.0
dl = base_length / n_elem
dt = 1e-3 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, helicalbuckling_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)

if SAVE_RESULTS:
    import pickle

    filename = "HelicalBuckling_data.dat"
    # Context manager guarantees the file is closed even if pickling fails.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import sys
  4import os
  5import numpy as np
  6
  7sys.path.append("../../")
  8from elastica import *
  9
 10from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
 11    plot_snake_velocity,
 12    plot_video,
 13    compute_projected_velocity,
 14    plot_curvature,
 15)
 16
 17
 18class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
 19    pass
 20
 21
 22def run_snake(
 23    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 24):
 25    # Initialize the simulation class
 26    snake_sim = SnakeSimulator()
 27
 28    # Simulation parameters
 29    period = 2
 30    final_time = (11.0 + 0.01) * period
 31    time_step = 8e-6
 32    total_steps = int(final_time / time_step)
 33    rendering_fps = 60
 34    step_skip = int(1.0 / (rendering_fps * time_step))
 35
 36    # setting up test params
 37    n_elem = 50
 38    start = np.zeros((3,))
 39    direction = np.array([0.0, 0.0, 1.0])
 40    normal = np.array([0.0, 1.0, 0.0])
 41    base_length = 0.35
 42    base_radius = base_length * 0.011
 43    density = 1000
 44    nu = 1e-4
 45    E = 1e6
 46    poisson_ratio = 0.5
 47    shear_modulus = E / (poisson_ratio + 1.0)
 48
 49    shearable_rod = CosseratRod.straight_rod(
 50        n_elem,
 51        start,
 52        direction,
 53        normal,
 54        base_length,
 55        base_radius,
 56        density,
 57        nu,
 58        E,
 59        shear_modulus=shear_modulus,
 60    )
 61
 62    snake_sim.append(shearable_rod)
 63
 64    # Add gravitational forces
 65    gravitational_acc = -9.80665
 66    snake_sim.add_forcing_to(shearable_rod).using(
 67        GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 68    )
 69
 70    # Add muscle torques
 71    wave_length = b_coeff[-1]
 72    snake_sim.add_forcing_to(shearable_rod).using(
 73        MuscleTorques,
 74        base_length=base_length,
 75        b_coeff=b_coeff[:-1],
 76        period=period,
 77        wave_number=2.0 * np.pi / (wave_length),
 78        phase_shift=0.0,
 79        rest_lengths=shearable_rod.rest_lengths,
 80        ramp_up_time=period,
 81        direction=normal,
 82        with_spline=True,
 83    )
 84
 85    # Add friction forces
 86    origin_plane = np.array([0.0, -base_radius, 0.0])
 87    normal_plane = normal
 88    slip_velocity_tol = 1e-8
 89    froude = 0.1
 90    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 91    kinetic_mu_array = np.array(
 92        [mu, 1.5 * mu, 2.0 * mu]
 93    )  # [forward, backward, sideways]
 94    static_mu_array = np.zeros(kinetic_mu_array.shape)
 95    snake_sim.add_forcing_to(shearable_rod).using(
 96        AnisotropicFrictionalPlane,
 97        k=1.0,
 98        nu=1e-6,
 99        plane_origin=origin_plane,
100        plane_normal=normal_plane,
101        slip_velocity_tol=slip_velocity_tol,
102        static_mu_array=static_mu_array,
103        kinetic_mu_array=kinetic_mu_array,
104    )
105
106    # Add call backs
107    class ContinuumSnakeCallBack(CallBackBaseClass):
108        """
109        Call back function for continuum snake
110        """
111
112        def __init__(self, step_skip: int, callback_params: dict):
113            CallBackBaseClass.__init__(self)
114            self.every = step_skip
115            self.callback_params = callback_params
116
117        def make_callback(self, system, time, current_step: int):
118
119            if current_step % self.every == 0:
120
121                self.callback_params["time"].append(time)
122                self.callback_params["step"].append(current_step)
123                self.callback_params["position"].append(
124                    system.position_collection.copy()
125                )
126                self.callback_params["velocity"].append(
127                    system.velocity_collection.copy()
128                )
129                self.callback_params["avg_velocity"].append(
130                    system.compute_velocity_center_of_mass()
131                )
132
133                self.callback_params["center_of_mass"].append(
134                    system.compute_position_center_of_mass()
135                )
136                self.callback_params["curvature"].append(system.kappa.copy())
137
138                return
139
140    pp_list = defaultdict(list)
141    snake_sim.collect_diagnostics(shearable_rod).using(
142        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
143    )
144
145    snake_sim.finalize()
146
147    timestepper = PositionVerlet()
148    integrate(timestepper, snake_sim, final_time, total_steps)
149
150    if PLOT_FIGURE:
151        filename_plot = "continuum_snake_velocity.png"
152        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
153        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
154
155        if SAVE_VIDEO:
156            filename_video = "continuum_snake.mp4"
157            plot_video(
158                pp_list,
159                video_name=filename_video,
160                fps=rendering_fps,
161                xlim=(0, 4),
162                ylim=(-1, 1),
163            )
164
165    if SAVE_RESULTS:
166        import pickle
167
168        filename = "continuum_snake.dat"
169        file = open(filename, "wb")
170        pickle.dump(pp_list, file)
171        file.close()
172
173    # Compute the average forward velocity. These will be used for optimization.
174    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
175
176    return avg_forward, avg_lateral, pp_list
177
178
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: negated average forward velocity, so that
            minimizing it maximizes the snake's forward speed."""
            avg_forward, _, _ = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are optimizing
        # for (6 spline coefficients + wave length) and third input is the
        # initial standard deviation.
        # cma.fmin returns a list of results whose first entry is the best
        # solution found; only that vector is the set of coefficients.
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]
        print("optimized coefficients:", optimized_spline_coefficients)

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            # run_snake expects the wave length as the last coefficient
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        avg_forward, avg_lateral, pp_list = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)