Example Cases

Example cases demonstrate physical systems with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as an initial template for building your own case study.

Axial Stretching

  1""" Axial stretching test-case
  2
  3    Assume we have a rod lying aligned in the x-direction, with high internal
  4    damping.
  5
  6    We fix one end (say, the left end) of the rod to a wall. On the right
  7    end we apply a force directed axially pulling the rods tip. Linear
  8    theory (assuming small displacements) predict that the net displacement
  9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10    usual meaning (the rod is just a linear spring). We compare our results
 11    with the above result.
 12
 13    We can "improve" the theory by having a better estimate for the rod's
 14    spring constant by assuming that it equilibriates under the new position,
 15    with
 16    Δx = F * (L + Δx)/ (A * E)
 17    which results in Δx = (F*l)/(A*E - F). Our rod reaches equilibrium wrt to
 18    this position.
 19
 20    Note that if the damping is not high, the rod oscillates about the eventual
 21    resting position (and this agrees with the theoretical predictions without
 22    any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24    isort:skip_file
 25"""
 26# FIXME without appending sys.path make it more generic
 27import sys
 28
 29sys.path.append("../../")  # isort:skip
 30
 31# from collections import defaultdict
 32
 33# import numpy as np
 34from matplotlib import pyplot as plt
 35
 36from elastica import *
 37
 38
 39class StretchingBeamSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
 40    pass
 41
 42
 43stretch_sim = StretchingBeamSimulator()
 44final_time = 20.0
 45
 46# Options
 47PLOT_FIGURE = True
 48SAVE_FIGURE = False
 49SAVE_RESULTS = False
 50
 51# setting up test params
 52n_elem = 19
 53start = np.zeros((3,))
 54direction = np.array([1.0, 0.0, 0.0])
 55normal = np.array([0.0, 1.0, 0.0])
 56base_length = 1.0
 57base_radius = 0.025
 58base_area = np.pi * base_radius ** 2
 59density = 1000
 60nu = 2.0
 61youngs_modulus = 1e4
 62# For shear modulus of 1e4, nu is 99!
 63poisson_ratio = 0.5
 64
 65stretchable_rod = CosseratRod.straight_rod(
 66    n_elem,
 67    start,
 68    direction,
 69    normal,
 70    base_length,
 71    base_radius,
 72    density,
 73    nu,
 74    youngs_modulus,
 75    poisson_ratio,
 76)
 77
 78stretch_sim.append(stretchable_rod)
 79stretch_sim.constrain(stretchable_rod).using(
 80    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 81)
 82
 83end_force_x = 1.0
 84end_force = np.array([end_force_x, 0.0, 0.0])
 85stretch_sim.add_forcing_to(stretchable_rod).using(
 86    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
 87)
 88
 89# Add call backs
 90class AxialStretchingCallBack(CallBackBaseClass):
 91    """
 92    Call back function for continuum snake
 93    """
 94
 95    def __init__(self, step_skip: int, callback_params: dict):
 96        CallBackBaseClass.__init__(self)
 97        self.every = step_skip
 98        self.callback_params = callback_params
 99
100    def make_callback(self, system, time, current_step: int):
101
102        if current_step % self.every == 0:
103
104            self.callback_params["time"].append(time)
105            # Collect only x
106            self.callback_params["position"].append(
107                system.position_collection[0, -1].copy()
108            )
109            return
110
111
recorded_history = defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dl = base_length / n_elem
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Simulated tip x-position over time, against the two analytical predictions
    # (dash-dot: base-length theory, dashed: modified-length theory).
    ax.plot(recorded_history["time"], recorded_history["position"], lw=2.0)
    ax.hlines(base_length + expected_tip_disp, 0.0, final_time, "k", "dashdot", lw=1.0)
    ax.hlines(
        base_length + expected_tip_disp_improved, 0.0, final_time, "k", "dashed", lw=2.0
    )
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to 
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import sys
  6
  7# FIXME without appending sys.path make it more generic
  8sys.path.append("../../")
  9from elastica import *
 10from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
 11
 12
 13class TimoshenkoBeamSimulator(BaseSystemCollection, Constraints, Forcing):
 14    pass
 15
 16
 17timoshenko_sim = TimoshenkoBeamSimulator()
 18final_time = 5000
 19
 20# Options
 21PLOT_FIGURE = True
 22SAVE_FIGURE = False
 23SAVE_RESULTS = False
 24ADD_UNSHEARABLE_ROD = True
 25
 26# setting up test params
 27n_elem = 100
 28start = np.zeros((3,))
 29direction = np.array([0.0, 0.0, 1.0])
 30normal = np.array([0.0, 1.0, 0.0])
 31base_length = 3.0
 32base_radius = 0.25
 33base_area = np.pi * base_radius ** 2
 34density = 5000
 35nu = 0.1
 36E = 1e6
 37# For shear modulus of 1e4, nu is 99!
 38poisson_ratio = 99
 39shear_modulus = E / (poisson_ratio + 1.0)
 40
 41shearable_rod = CosseratRod.straight_rod(
 42    n_elem,
 43    start,
 44    direction,
 45    normal,
 46    base_length,
 47    base_radius,
 48    density,
 49    nu,
 50    E,
 51    shear_modulus=shear_modulus,
 52)
 53
 54timoshenko_sim.append(shearable_rod)
 55timoshenko_sim.constrain(shearable_rod).using(
 56    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 57)
 58
 59end_force = np.array([-15.0, 0.0, 0.0])
 60timoshenko_sim.add_forcing_to(shearable_rod).using(
 61    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 62)
 63
 64
 65if ADD_UNSHEARABLE_ROD:
 66    # Start into the plane
 67    unshearable_start = np.array([0.0, -1.0, 0.0])
 68    shear_modulus = E / (-0.7 + 1.0)
 69    unshearable_rod = CosseratRod.straight_rod(
 70        n_elem,
 71        unshearable_start,
 72        direction,
 73        normal,
 74        base_length,
 75        base_radius,
 76        density,
 77        nu,
 78        E,
 79        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
 80        shear_modulus=shear_modulus,
 81    )
 82
 83    timoshenko_sim.append(unshearable_rod)
 84    timoshenko_sim.constrain(unshearable_rod).using(
 85        OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 86    )
 87    timoshenko_sim.add_forcing_to(unshearable_rod).using(
 88        EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 89    )
 90
 91timoshenko_sim.finalize()
 92timestepper = PositionVerlet()
 93# timestepper = PEFRL()
 94
 95dl = base_length / n_elem
 96dt = 0.01 * dl
 97total_steps = int(final_time / dt)
 98print("Total steps", total_steps)
 99integrate(timestepper, timoshenko_sim, final_time, total_steps)
100
101if PLOT_FIGURE:
102    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)
103
104if SAVE_RESULTS:
105    import pickle
106
107    filename = "Timoshenko_beam_data.dat"
108    file = open(filename, "wb")
109    pickle.dump(shearable_rod, file)
110    file.close()

Butterfly

  1# FIXME without appending sys.path make it more generic
  2import sys
  3
  4sys.path.append("../")
  5sys.path.append("../../")
  6
  7# from collections import defaultdict
  8import numpy as np
  9from matplotlib import pyplot as plt
 10from matplotlib.colors import to_rgb
 11
 12
 13from elastica import *
 14from elastica.utils import MaxDimension
 15
 16
 17class ButterflySimulator(BaseSystemCollection, CallBacks):
 18    pass
 19
 20
 21butterfly_sim = ButterflySimulator()
 22final_time = 40.0
 23
 24# Options
 25PLOT_FIGURE = True
 26SAVE_FIGURE = True
 27SAVE_RESULTS = True
 28ADD_UNSHEARABLE_ROD = False
 29
 30# setting up test params
 31# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
 32n_elem = 4  # Change based on requirements, but be careful
 33n_elem += n_elem % 2
 34half_n_elem = n_elem // 2
 35
 36origin = np.zeros((3, 1))
 37angle_of_inclination = np.deg2rad(45.0)
 38
 39# in-plane
 40horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
 41vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
 42
 43# out-of-plane
 44normal = np.array([0.0, 1.0, 0.0])
 45
 46total_length = 3.0
 47base_radius = 0.25
 48base_area = np.pi * base_radius ** 2
 49density = 5000
 50nu = 0.0
 51youngs_modulus = 1e4
 52poisson_ratio = 0.5
 53
 54positions = np.empty((MaxDimension.value(), n_elem + 1))
 55dl = total_length / n_elem
 56
 57# First half of positions stem from slope angle_of_inclination
 58first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
 59positions[..., : half_n_elem + 1] = origin + dl * first_half * (
 60    np.cos(angle_of_inclination) * horizontal_direction
 61    + np.sin(angle_of_inclination) * vertical_direction
 62)
 63positions[..., half_n_elem:] = positions[
 64    ..., half_n_elem : half_n_elem + 1
 65] + dl * first_half * (
 66    np.cos(angle_of_inclination) * horizontal_direction
 67    - np.sin(angle_of_inclination) * vertical_direction
 68)
 69
 70butterfly_rod = CosseratRod.straight_rod(
 71    n_elem,
 72    start=origin.reshape(3),
 73    direction=np.array([0.0, 0.0, 1.0]),
 74    normal=normal,
 75    base_length=total_length,
 76    base_radius=base_radius,
 77    density=density,
 78    nu=nu,
 79    youngs_modulus=youngs_modulus,
 80    poisson_ratio=poisson_ratio,
 81    position=positions,
 82)
 83
 84butterfly_sim.append(butterfly_rod)
 85
 86# Add call backs
 87class VelocityCallBack(CallBackBaseClass):
 88    """
 89    Call back function for continuum snake
 90    """
 91
 92    def __init__(self, step_skip: int, callback_params: dict):
 93        CallBackBaseClass.__init__(self)
 94        self.every = step_skip
 95        self.callback_params = callback_params
 96
 97    def make_callback(self, system, time, current_step: int):
 98
 99        if current_step % self.every == 0:
100
101            self.callback_params["time"].append(time)
102            # Collect x
103            self.callback_params["position"].append(system.position_collection.copy())
104            # Collect energies as well
105            self.callback_params["te"].append(system.compute_translational_energy())
106            self.callback_params["re"].append(system.compute_rotational_energy())
107            self.callback_params["se"].append(system.compute_shear_energy())
108            self.callback_params["be"].append(system.compute_bending_energy())
109            return
110
111
recorded_history = defaultdict(list)
# initially record history (t = 0), using the same keys as VelocityCallBack
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
recorded_history["te"].append(butterfly_rod.compute_translational_energy())
recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
recorded_history["se"].append(butterfly_rod.compute_shear_energy())
recorded_history["be"].append(butterfly_rod.compute_bending_energy())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, butterfly_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Index into a shallow copy instead of popping, so that
    # recorded_history["position"] is left intact for later use.
    positions = list(recorded_history["position"])
    first_position = positions[0]
    last_position = positions[-1]
    # initial configuration: red, dashed
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    # Intermediate snapshots only (the final one is drawn separately below, so
    # it is not drawn twice); later snapshots are drawn more opaque.
    intermediate_positions = positions[1:-1]
    n_positions = len(intermediate_positions)
    for i, pos in enumerate(intermediate_positions):
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final configuration: black, dashed
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)

Helical Buckling

 1__doc__ = """Helical buckling validation case, for detailed explanation refer to 
 2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
 3
 4import numpy as np
 5import sys
 6
 7# FIXME without appending sys.path make it more generic
 8sys.path.append("../../")
 9from elastica import *
10from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
11    plot_helicalbuckling,
12)
13
14
class HelicalBucklingSimulator(BaseSystemCollection, Constraints, Forcing):
    """Simulator for the helical buckling case.

    A system collection with constraint and forcing support mixed in; it adds
    no behavior of its own.
    """


helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius ** 2
density = 1.0 / (base_area)  # density * area = 1, i.e. unit mass per unit length
nu = 0.01  # damping coefficient (not the Poisson ratio)
E = 1e6
slack = 3
number_of_rotations = 27
# This case sets the shear modulus directly as G = E / (poisson_ratio + 1):
# with E = 1e6, poisson_ratio = 9 gives G = 1e5.
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear stiffness: a (3, 3, n_elem) stack of G * I.
shear_matrix = np.repeat(
    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
)
# Prescribed diagonal bend stiffness: a (3, 3, n_elem - 1) stack.
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

# Overwrite the stiffness matrices built by straight_rod with the ones above.
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
# Both ends are constrained; the boundary condition applies the slack and
# the prescribed number of end rotations over `twisting_time`.
helicalbuckling_sim.constrain(shearable_rod).using(
    HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)

helicalbuckling_sim.finalize()
timestepper = PositionVerlet()
# Tiny y-velocity perturbation at the middle node (presumably to break the
# straight rod's symmetry so buckling can develop).
shearable_rod.velocity_collection[..., n_elem // 2] += np.array([0, 1e-6, 0.0])
# timestepper = PEFRL()

final_time = 10500.0
dl = base_length / n_elem
dt = 1e-3 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, helicalbuckling_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)

if SAVE_RESULTS:
    import pickle

    filename = "HelicalBuckling_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import sys
  4
  5sys.path.append("../../")
  6from elastica import *
  7
  8from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
  9    plot_snake_velocity,
 10    plot_video,
 11    compute_projected_velocity,
 12    plot_curvature,
 13)
 14
 15
 16class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
 17    pass
 18
 19
 20def run_snake(
 21    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 22):
 23    # Initialize the simulation class
 24    snake_sim = SnakeSimulator()
 25
 26    # Simulation parameters
 27    period = 2
 28    final_time = (11.0 + 0.01) * period
 29    time_step = 8e-6
 30    total_steps = int(final_time / time_step)
 31    rendering_fps = 60
 32    step_skip = int(1.0 / (rendering_fps * time_step))
 33
 34    # setting up test params
 35    n_elem = 50
 36    start = np.zeros((3,))
 37    direction = np.array([0.0, 0.0, 1.0])
 38    normal = np.array([0.0, 1.0, 0.0])
 39    base_length = 0.35
 40    base_radius = base_length * 0.011
 41    density = 1000
 42    nu = 1e-4
 43    E = 1e6
 44    poisson_ratio = 0.5
 45    shear_modulus = E / (poisson_ratio + 1.0)
 46
 47    shearable_rod = CosseratRod.straight_rod(
 48        n_elem,
 49        start,
 50        direction,
 51        normal,
 52        base_length,
 53        base_radius,
 54        density,
 55        nu,
 56        E,
 57        shear_modulus=shear_modulus,
 58    )
 59
 60    snake_sim.append(shearable_rod)
 61
 62    # Add gravitational forces
 63    gravitational_acc = -9.80665
 64    snake_sim.add_forcing_to(shearable_rod).using(
 65        GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 66    )
 67
 68    # Add muscle torques
 69    wave_length = b_coeff[-1]
 70    snake_sim.add_forcing_to(shearable_rod).using(
 71        MuscleTorques,
 72        base_length=base_length,
 73        b_coeff=b_coeff[:-1],
 74        period=period,
 75        wave_number=2.0 * np.pi / (wave_length),
 76        phase_shift=0.0,
 77        rest_lengths=shearable_rod.rest_lengths,
 78        ramp_up_time=period,
 79        direction=normal,
 80        with_spline=True,
 81    )
 82
 83    # Add friction forces
 84    origin_plane = np.array([0.0, -base_radius, 0.0])
 85    normal_plane = normal
 86    slip_velocity_tol = 1e-8
 87    froude = 0.1
 88    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 89    kinetic_mu_array = np.array(
 90        [mu, 1.5 * mu, 2.0 * mu]
 91    )  # [forward, backward, sideways]
 92    static_mu_array = np.zeros(kinetic_mu_array.shape)
 93    snake_sim.add_forcing_to(shearable_rod).using(
 94        AnisotropicFrictionalPlane,
 95        k=1.0,
 96        nu=1e-6,
 97        plane_origin=origin_plane,
 98        plane_normal=normal_plane,
 99        slip_velocity_tol=slip_velocity_tol,
100        static_mu_array=static_mu_array,
101        kinetic_mu_array=kinetic_mu_array,
102    )
103
104    # Add call backs
105    class ContinuumSnakeCallBack(CallBackBaseClass):
106        """
107        Call back function for continuum snake
108        """
109
110        def __init__(self, step_skip: int, callback_params: dict):
111            CallBackBaseClass.__init__(self)
112            self.every = step_skip
113            self.callback_params = callback_params
114
115        def make_callback(self, system, time, current_step: int):
116
117            if current_step % self.every == 0:
118
119                self.callback_params["time"].append(time)
120                self.callback_params["step"].append(current_step)
121                self.callback_params["position"].append(
122                    system.position_collection.copy()
123                )
124                self.callback_params["velocity"].append(
125                    system.velocity_collection.copy()
126                )
127                self.callback_params["avg_velocity"].append(
128                    system.compute_velocity_center_of_mass()
129                )
130
131                self.callback_params["center_of_mass"].append(
132                    system.compute_position_center_of_mass()
133                )
134                self.callback_params["curvature"].append(system.kappa.copy())
135
136                return
137
138    pp_list = defaultdict(list)
139    snake_sim.collect_diagnostics(shearable_rod).using(
140        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
141    )
142
143    snake_sim.finalize()
144
145    timestepper = PositionVerlet()
146    integrate(timestepper, snake_sim, final_time, total_steps)
147
148    if PLOT_FIGURE:
149        filename_plot = "continuum_snake_velocity.png"
150        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
151        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
152
153        if SAVE_VIDEO:
154            filename_video = "continuum_snake.mp4"
155            plot_video(
156                pp_list,
157                video_name=filename_video,
158                fps=rendering_fps,
159                xlim=(0, 4),
160                ylim=(-1, 1),
161            )
162
163    if SAVE_RESULTS:
164        import pickle
165
166        filename = "continuum_snake.dat"
167        file = open(filename, "wb")
168        pickle.dump(pp_list, file)
169        file.close()
170
171    # Compute the average forward velocity. These will be used for optimization.
172    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
173
174    return avg_forward, avg_lateral, pp_list
175
176
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: the negated average forward velocity."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are optimizing
        # for and third input is standard deviation you initially set.
        # NOTE(review): the 7th entry is the wave length, which run_snake uses as
        # 2*pi / wave_length; an initial guess of 0 sits on that singularity.
        # Consider a non-zero starting value.
        # cma.fmin returns a tuple of results; its first entry is the best
        # solution vector, which is what gets saved and later reloaded.
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            assert filename_data != "", "provide a file name for coefficients"
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            # The wave length is appended as the last coefficient, as run_snake expects.
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)