# Example Cases

Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as an initial template to build your own case study.

## Axial Stretching

```  1""" Axial stretching test-case
2
3    Assume we have a rod lying aligned in the x-direction, with high internal
4    damping.
5
6    We fix one end (say, the left end) of the rod to a wall. On the right
7    end we apply a force directed axially pulling the rods tip. Linear
8    theory (assuming small displacements) predict that the net displacement
9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
10    usual meaning (the rod is just a linear spring). We compare our results
11    with the above result.
12
13    We can "improve" the theory by having a better estimate for the rod's
14    spring constant by assuming that it equilibriates under the new position,
15    with
16    Δx = F * (L + Δx)/ (A * E)
17    which results in Δx = (F*l)/(A*E - F). Our rod reaches equilibrium wrt to
18    this position.
19
20    Note that if the damping is not high, the rod oscillates about the eventual
21    resting position (and this agrees with the theoretical predictions without
22    any damping : we should see the rod oscillating simple-harmonically in time).
23
24    isort:skip_file
25"""
26import numpy as np
27from matplotlib import pyplot as plt
28
29from elastica import *
30
31
32class StretchingBeamSimulator(
33    BaseSystemCollection, Constraints, Forcing, Damping, CallBacks
34):
35    pass
36
37
38stretch_sim = StretchingBeamSimulator()
39final_time = 200.0
40
41# Options
42PLOT_FIGURE = True
43SAVE_FIGURE = False
44SAVE_RESULTS = False
45
46# setting up test params
47n_elem = 19
48start = np.zeros((3,))
49direction = np.array([1.0, 0.0, 0.0])
50normal = np.array([0.0, 1.0, 0.0])
51base_length = 1.0
53base_area = np.pi * base_radius ** 2
54density = 1000
55youngs_modulus = 1e4
56# For shear modulus of 1e4, nu is 99!
57poisson_ratio = 0.5
58shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
59
60stretchable_rod = CosseratRod.straight_rod(
61    n_elem,
62    start,
63    direction,
64    normal,
65    base_length,
67    density,
68    0.0,  # internal damping constant, deprecated in v0.3.0
69    youngs_modulus,
70    shear_modulus=shear_modulus,
71)
72
73stretch_sim.append(stretchable_rod)
74stretch_sim.constrain(stretchable_rod).using(
75    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
76)
77
78end_force_x = 1.0
79end_force = np.array([end_force_x, 0.0, 0.0])
81    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
82)
83
85dl = base_length / n_elem
86# old damping model (deprecated in v0.3.0) values
87# dt = 0.01 * dl
88# damping_constant = 1.0
89dt = 0.1 * dl
90damping_constant = 0.1
91stretch_sim.dampen(stretchable_rod).using(
92    AnalyticalLinearDamper,
93    damping_constant=damping_constant,
94    time_step=dt,
95)
96
98class AxialStretchingCallBack(CallBackBaseClass):
99    """
100    Tracks the velocity norms of the rod
101    """
102
103    def __init__(self, step_skip: int, callback_params: dict):
104        CallBackBaseClass.__init__(self)
105        self.every = step_skip
106        self.callback_params = callback_params
107
108    def make_callback(self, system, time, current_step: int):
109
110        if current_step % self.every == 0:
111
112            self.callback_params["time"].append(time)
113            # Collect only x
114            self.callback_params["position"].append(
115                system.position_collection[0, -1].copy()
116            )
117            self.callback_params["velocity_norms"].append(
118                np.linalg.norm(system.velocity_collection.copy())
119            )
120            return
121
122
123recorded_history = defaultdict(list)
124stretch_sim.collect_diagnostics(stretchable_rod).using(
125    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
126)
127
128stretch_sim.finalize()
129timestepper = PositionVerlet()
130# timestepper = PEFRL()
131
132total_steps = int(final_time / dt)
133print("Total steps", total_steps)
134integrate(timestepper, stretch_sim, final_time, total_steps)
135
136if PLOT_FIGURE:
137    # First-order theory with base-length
138    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
139    # First-order theory with modified-length, gives better estimates
140    expected_tip_disp_improved = (
141        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
142    )
143
144    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
146    ax.plot(recorded_history["time"], recorded_history["position"], lw=2.0)
147    ax.hlines(base_length + expected_tip_disp, 0.0, final_time, "k", "dashdot", lw=1.0)
148    ax.hlines(
149        base_length + expected_tip_disp_improved, 0.0, final_time, "k", "dashed", lw=2.0
150    )
151    if SAVE_FIGURE:
152        fig.savefig("axial_stretching.pdf")
153    plt.show()
154
155if SAVE_RESULTS:
156    import pickle
157
158    filename = "axial_stretching_data.dat"
159    file = open(filename, "wb")
160    pickle.dump(stretchable_rod, file)
161    file.close()
162
163    tv = (
164        np.asarray(recorded_history["time"]),
165        np.asarray(recorded_history["velocity_norms"]),
166    )
167
168    def as_time_series(v):
169        return v.T
170
171    np.savetxt(
172        "velocity_norms.csv",
173        as_time_series(np.stack(tv)),
174        delimiter=",",
175    )
```

## Timoshenko

```  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
3
4import numpy as np
5from elastica import *
6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
7
8
9class TimoshenkoBeamSimulator(
10    BaseSystemCollection, Constraints, Forcing, CallBacks, Damping
11):
12    pass
13
14
15timoshenko_sim = TimoshenkoBeamSimulator()
16final_time = 5000.0
17
18# Options
19PLOT_FIGURE = True
20SAVE_FIGURE = True
21SAVE_RESULTS = False
23
24# setting up test params
25n_elem = 100
26start = np.zeros((3,))
27direction = np.array([0.0, 0.0, 1.0])
28normal = np.array([0.0, 1.0, 0.0])
29base_length = 3.0
31base_area = np.pi * base_radius ** 2
32density = 5000
33nu = 0.1 / 7 / density / base_area
34E = 1e6
35# For shear modulus of 1e4, nu is 99!
36poisson_ratio = 99
37shear_modulus = E / (poisson_ratio + 1.0)
38
39shearable_rod = CosseratRod.straight_rod(
40    n_elem,
41    start,
42    direction,
43    normal,
44    base_length,
46    density,
47    0.0,  # internal damping constant, deprecated in v0.3.0
48    E,
49    shear_modulus=shear_modulus,
50)
51
52timoshenko_sim.append(shearable_rod)
54dl = base_length / n_elem
55dt = 0.07 * dl
56timoshenko_sim.dampen(shearable_rod).using(
57    AnalyticalLinearDamper,
58    damping_constant=nu,
59    time_step=dt,
60)
61
62timoshenko_sim.constrain(shearable_rod).using(
63    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
64)
65
66end_force = np.array([-15.0, 0.0, 0.0])
68    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
69)
70
71
73    # Start into the plane
74    unshearable_start = np.array([0.0, -1.0, 0.0])
75    shear_modulus = E / (-0.7 + 1.0)
76    unshearable_rod = CosseratRod.straight_rod(
77        n_elem,
78        unshearable_start,
79        direction,
80        normal,
81        base_length,
83        density,
84        0.0,  # internal damping constant, deprecated in v0.3.0
85        E,
86        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
87        shear_modulus=shear_modulus,
88    )
89
90    timoshenko_sim.append(unshearable_rod)
91
93    timoshenko_sim.dampen(unshearable_rod).using(
94        AnalyticalLinearDamper,
95        damping_constant=nu,
96        time_step=dt,
97    )
98    timoshenko_sim.constrain(unshearable_rod).using(
99        OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
100    )
102        EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
103    )
104
106class VelocityCallBack(CallBackBaseClass):
107    """
108    Tracks the velocity norms of the rod
109    """
110
111    def __init__(self, step_skip: int, callback_params: dict):
112        CallBackBaseClass.__init__(self)
113        self.every = step_skip
114        self.callback_params = callback_params
115
116    def make_callback(self, system, time, current_step: int):
117
118        if current_step % self.every == 0:
119
120            self.callback_params["time"].append(time)
121            # Collect x
122            self.callback_params["velocity_norms"].append(
123                np.linalg.norm(system.velocity_collection.copy())
124            )
125            return
126
127
128recorded_history = defaultdict(list)
129timoshenko_sim.collect_diagnostics(shearable_rod).using(
130    VelocityCallBack, step_skip=500, callback_params=recorded_history
131)
132
133timoshenko_sim.finalize()
134timestepper = PositionVerlet()
135# timestepper = PEFRL()
136
137total_steps = int(final_time / dt)
138print("Total steps", total_steps)
139integrate(timestepper, timoshenko_sim, final_time, total_steps)
140
141if PLOT_FIGURE:
143
144if SAVE_RESULTS:
145    import pickle
146
147    filename = "Timoshenko_beam_data.dat"
148    file = open(filename, "wb")
149    pickle.dump(shearable_rod, file)
150    file.close()
151
152    tv = (
153        np.asarray(recorded_history["time"]),
154        np.asarray(recorded_history["velocity_norms"]),
155    )
156
157    def as_time_series(v):
158        return v.T
159
160    np.savetxt(
161        "velocity_norms.csv",
162        as_time_series(np.stack(tv)),
163        delimiter=",",
164    )
```

## Butterfly

```  1import numpy as np
2from matplotlib import pyplot as plt
3from matplotlib.colors import to_rgb
4
5
6from elastica import *
7from elastica.utils import MaxDimension
8
9
10class ButterflySimulator(BaseSystemCollection, CallBacks):
11    pass
12
13
14butterfly_sim = ButterflySimulator()
15final_time = 40.0
16
17# Options
18PLOT_FIGURE = True
19SAVE_FIGURE = True
20SAVE_RESULTS = True
22
23# setting up test params
24# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
25n_elem = 4  # Change based on requirements, but be careful
26n_elem += n_elem % 2
27half_n_elem = n_elem // 2
28
29origin = np.zeros((3, 1))
31
32# in-plane
33horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
34vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
35
36# out-of-plane
37normal = np.array([0.0, 1.0, 0.0])
38
39total_length = 3.0
41base_area = np.pi * base_radius ** 2
42density = 5000
43youngs_modulus = 1e4
44poisson_ratio = 0.5
45shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
46
47positions = np.empty((MaxDimension.value(), n_elem + 1))
48dl = total_length / n_elem
49
50# First half of positions stem from slope angle_of_inclination
51first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
52positions[..., : half_n_elem + 1] = origin + dl * first_half * (
53    np.cos(angle_of_inclination) * horizontal_direction
54    + np.sin(angle_of_inclination) * vertical_direction
55)
56positions[..., half_n_elem:] = positions[
57    ..., half_n_elem : half_n_elem + 1
58] + dl * first_half * (
59    np.cos(angle_of_inclination) * horizontal_direction
60    - np.sin(angle_of_inclination) * vertical_direction
61)
62
63butterfly_rod = CosseratRod.straight_rod(
64    n_elem,
65    start=origin.reshape(3),
66    direction=np.array([0.0, 0.0, 1.0]),
67    normal=normal,
68    base_length=total_length,
70    density=density,
71    nu=0.0,  # internal damping constant, deprecated in v0.3.0
72    youngs_modulus=youngs_modulus,
73    shear_modulus=shear_modulus,
74    position=positions,
75)
76
77butterfly_sim.append(butterfly_rod)
78
80class VelocityCallBack(CallBackBaseClass):
81    """
82    Call back function for continuum snake
83    """
84
85    def __init__(self, step_skip: int, callback_params: dict):
86        CallBackBaseClass.__init__(self)
87        self.every = step_skip
88        self.callback_params = callback_params
89
90    def make_callback(self, system, time, current_step: int):
91
92        if current_step % self.every == 0:
93
94            self.callback_params["time"].append(time)
95            # Collect x
96            self.callback_params["position"].append(system.position_collection.copy())
97            # Collect energies as well
98            self.callback_params["te"].append(system.compute_translational_energy())
99            self.callback_params["re"].append(system.compute_rotational_energy())
100            self.callback_params["se"].append(system.compute_shear_energy())
101            self.callback_params["be"].append(system.compute_bending_energy())
102            return
103
104
105recorded_history = defaultdict(list)
106# initially record history
107recorded_history["time"].append(0.0)
108recorded_history["position"].append(butterfly_rod.position_collection.copy())
109recorded_history["te"].append(butterfly_rod.compute_translational_energy())
110recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
111recorded_history["se"].append(butterfly_rod.compute_shear_energy())
112recorded_history["be"].append(butterfly_rod.compute_bending_energy())
113
114butterfly_sim.collect_diagnostics(butterfly_rod).using(
115    VelocityCallBack, step_skip=100, callback_params=recorded_history
116)
117
118
119butterfly_sim.finalize()
120timestepper = PositionVerlet()
121# timestepper = PEFRL()
122
123dt = 0.01 * dl
124total_steps = int(final_time / dt)
125print("Total steps", total_steps)
126integrate(timestepper, butterfly_sim, final_time, total_steps)
127
128if PLOT_FIGURE:
129    # Plot the histories
130    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
132    positions = recorded_history["position"]
133    # record first position
134    first_position = positions.pop(0)
135    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
136    n_positions = len(positions)
137    for i, pos in enumerate(positions):
138        alpha = np.exp(i / n_positions - 1)
139        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
140    # final position is also separate
141    last_position = positions.pop()
142    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
143    # don't block
144    fig.show()
145
146    # Plot the energies
147    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
149    times = np.asarray(recorded_history["time"])
150    te = np.asarray(recorded_history["te"])
151    re = np.asarray(recorded_history["re"])
152    be = np.asarray(recorded_history["be"])
153    se = np.asarray(recorded_history["se"])
154
155    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
156    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
157    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
158    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
159    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
160    energy_ax.legend()
161    # don't block
162    energy_fig.show()
163
164    if SAVE_FIGURE:
165        fig.savefig("butterfly.png")
166        energy_fig.savefig("energies.png")
167
168    plt.show()
169
170if SAVE_RESULTS:
171    import pickle
172
173    filename = "butterfly_data.dat"
174    file = open(filename, "wb")
175    pickle.dump(butterfly_rod, file)
176    file.close()
```

## Helical Buckling

```  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
3
4import numpy as np
5from elastica import *
6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
7    plot_helicalbuckling,
8)
9
10
11class HelicalBucklingSimulator(BaseSystemCollection, Constraints, Damping, Forcing):
12    pass
13
14
15helicalbuckling_sim = HelicalBucklingSimulator()
16
17# Options
18PLOT_FIGURE = True
19SAVE_FIGURE = True
20SAVE_RESULTS = False
21
22# setting up test params
23n_elem = 100
24start = np.zeros((3,))
25direction = np.array([0.0, 0.0, 1.0])
26normal = np.array([0.0, 1.0, 0.0])
27base_length = 100.0
29base_area = np.pi * base_radius ** 2
30density = 1.0 / (base_area)
31nu = 0.01 / density / base_area
32E = 1e6
33slack = 3
34number_of_rotations = 27
35# For shear modulus of 1e5, nu is 99!
36poisson_ratio = 9
37shear_modulus = E / (poisson_ratio + 1.0)
38shear_matrix = np.repeat(
39    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
40)
41temp_bend_matrix = np.zeros((3, 3))
42np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
43bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)
44
45shearable_rod = CosseratRod.straight_rod(
46    n_elem,
47    start,
48    direction,
49    normal,
50    base_length,
52    density,
53    0.0,  # internal damping constant, deprecated in v0.3.0
54    E,
55    shear_modulus=shear_modulus,
56)
57# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below
58
59shearable_rod.shear_matrix = shear_matrix
60shearable_rod.bend_matrix = bend_matrix
61
62
63helicalbuckling_sim.append(shearable_rod)
65dl = base_length / n_elem
66dt = 1e-3 * dl
67helicalbuckling_sim.dampen(shearable_rod).using(
68    AnalyticalLinearDamper,
69    damping_constant=nu,
70    time_step=dt,
71)
72
73helicalbuckling_sim.constrain(shearable_rod).using(
74    HelicalBucklingBC,
75    constrained_position_idx=(0, -1),
76    constrained_director_idx=(0, -1),
77    twisting_time=500,
78    slack=slack,
79    number_of_rotations=number_of_rotations,
80)
81
82helicalbuckling_sim.finalize()
83timestepper = PositionVerlet()
84shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
85# timestepper = PEFRL()
86
87final_time = 10500.0
88total_steps = int(final_time / dt)
89print("Total steps", total_steps)
90integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
91
92if PLOT_FIGURE:
93    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
94
95if SAVE_RESULTS:
96    import pickle
97
98    filename = "HelicalBuckling_data.dat"
99    file = open(filename, "wb")
100    pickle.dump(shearable_rod, file)
101    file.close()
```

## Continuum Snake

```  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
2
3import os
4import numpy as np
5from elastica import *
6
7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
8    plot_snake_velocity,
9    plot_video,
10    compute_projected_velocity,
11    plot_curvature,
12)
13
14
15class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, Damping, CallBacks):
16    pass
17
18
19def run_snake(
20    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
21):
22    # Initialize the simulation class
23    snake_sim = SnakeSimulator()
24
25    # Simulation parameters
26    period = 2
27    final_time = (11.0 + 0.01) * period
28
29    # setting up test params
30    n_elem = 50
31    start = np.zeros((3,))
32    direction = np.array([0.0, 0.0, 1.0])
33    normal = np.array([0.0, 1.0, 0.0])
34    base_length = 0.35
35    base_radius = base_length * 0.011
36    density = 1000
37    E = 1e6
38    poisson_ratio = 0.5
39    shear_modulus = E / (poisson_ratio + 1.0)
40
41    shearable_rod = CosseratRod.straight_rod(
42        n_elem,
43        start,
44        direction,
45        normal,
46        base_length,
48        density,
49        0.0,  # internal damping constant, deprecated in v0.3.0
50        E,
51        shear_modulus=shear_modulus,
52    )
53
54    snake_sim.append(shearable_rod)
55
57    gravitational_acc = -9.80665
59        GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
60    )
61
63    wave_length = b_coeff[-1]
65        MuscleTorques,
66        base_length=base_length,
67        b_coeff=b_coeff[:-1],
68        period=period,
69        wave_number=2.0 * np.pi / (wave_length),
70        phase_shift=0.0,
71        rest_lengths=shearable_rod.rest_lengths,
72        ramp_up_time=period,
73        direction=normal,
74        with_spline=True,
75    )
76
78    origin_plane = np.array([0.0, -base_radius, 0.0])
79    normal_plane = normal
80    slip_velocity_tol = 1e-8
81    froude = 0.1
82    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
83    kinetic_mu_array = np.array(
84        [mu, 1.5 * mu, 2.0 * mu]
85    )  # [forward, backward, sideways]
86    static_mu_array = np.zeros(kinetic_mu_array.shape)
88        AnisotropicFrictionalPlane,
89        k=1.0,
90        nu=1e-6,
91        plane_origin=origin_plane,
92        plane_normal=normal_plane,
93        slip_velocity_tol=slip_velocity_tol,
94        static_mu_array=static_mu_array,
95        kinetic_mu_array=kinetic_mu_array,
96    )
97
99    # old damping model (deprecated in v0.3.0) values
100    # damping_constant = 2e-3
101    # time_step = 8e-6
102    damping_constant = 2e-3
103    time_step = 1e-4
104    snake_sim.dampen(shearable_rod).using(
105        AnalyticalLinearDamper,
106        damping_constant=damping_constant,
107        time_step=time_step,
108    )
109
110    total_steps = int(final_time / time_step)
111    rendering_fps = 60
112    step_skip = int(1.0 / (rendering_fps * time_step))
113
115    class ContinuumSnakeCallBack(CallBackBaseClass):
116        """
117        Call back function for continuum snake
118        """
119
120        def __init__(self, step_skip: int, callback_params: dict):
121            CallBackBaseClass.__init__(self)
122            self.every = step_skip
123            self.callback_params = callback_params
124
125        def make_callback(self, system, time, current_step: int):
126
127            if current_step % self.every == 0:
128
129                self.callback_params["time"].append(time)
130                self.callback_params["step"].append(current_step)
131                self.callback_params["position"].append(
132                    system.position_collection.copy()
133                )
134                self.callback_params["velocity"].append(
135                    system.velocity_collection.copy()
136                )
137                self.callback_params["avg_velocity"].append(
138                    system.compute_velocity_center_of_mass()
139                )
140
141                self.callback_params["center_of_mass"].append(
142                    system.compute_position_center_of_mass()
143                )
144                self.callback_params["curvature"].append(system.kappa.copy())
145
146                return
147
148    pp_list = defaultdict(list)
149    snake_sim.collect_diagnostics(shearable_rod).using(
150        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
151    )
152
153    snake_sim.finalize()
154
155    timestepper = PositionVerlet()
156    integrate(timestepper, snake_sim, final_time, total_steps)
157
158    if PLOT_FIGURE:
159        filename_plot = "continuum_snake_velocity.png"
160        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
161        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
162
163        if SAVE_VIDEO:
164            filename_video = "continuum_snake.mp4"
165            plot_video(
166                pp_list,
167                video_name=filename_video,
168                fps=rendering_fps,
169                xlim=(0, 4),
170                ylim=(-1, 1),
171            )
172
173    if SAVE_RESULTS:
174        import pickle
175
176        filename = "continuum_snake.dat"
177        file = open(filename, "wb")
178        pickle.dump(pp_list, file)
179        file.close()
180
181    # Compute the average forward velocity. These will be used for optimization.
182    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
183
184    return avg_forward, avg_lateral, pp_list
185
186
187if __name__ == "__main__":
188
189    # Options
190    PLOT_FIGURE = True
191    SAVE_FIGURE = True
192    SAVE_VIDEO = True
193    SAVE_RESULTS = False
194    CMA_OPTION = False
195
196    if CMA_OPTION:
197        import cma
198
199        SAVE_OPTIMIZED_COEFFICIENTS = False
200
201        def optimize_snake(spline_coefficient):
202            [avg_forward, _, _] = run_snake(
203                spline_coefficient,
204                PLOT_FIGURE=False,
205                SAVE_FIGURE=False,
206                SAVE_VIDEO=False,
207                SAVE_RESULTS=False,
208            )
209            return -avg_forward
210
211        # Optimize snake for forward velocity. In cma.fmin first input is function
212        # to be optimized, second input is initial guess for coefficients you are optimizing
213        # for and third input is standard deviation you initially set.
214        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)
215
216        # Save the optimized coefficients to a file
217        filename_data = "optimized_coefficients.txt"
218        if SAVE_OPTIMIZED_COEFFICIENTS:
219            assert filename_data != "", "provide a file name for coefficients"
220            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")
221
222    else:
223        # Add muscle forces on the rod
224        if os.path.exists("optimized_coefficients.txt"):
225            t_coeff_optimized = np.genfromtxt(
226                "optimized_coefficients.txt", delimiter=","
227            )
228        else:
229            wave_length = 1.0
230            t_coeff_optimized = np.array(
231                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
232            )
233            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))
234
235        # run the simulation
236        [avg_forward, avg_lateral, pp_list] = run_snake(
237            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
238        )
239
240        print("average forward velocity:", avg_forward)
241        print("average forward lateral:", avg_lateral)
```