Example Cases
Contents
Example Cases
Example cases demonstrate physical examples that have a known analytical solution or a well-studied phenomenon. Each case follows the recommended workflow shown here. Feel free to use them as an initial template to build your own case study.
Axial Stretching
1""" Axial stretching test-case
2
Assume we have a rod lying aligned in the x-direction, with high internal
damping.

We fix one end (say, the left end) of the rod to a wall. On the right
end we apply a force directed axially, pulling the rod's tip. Linear
theory (assuming small displacements) predicts that the net displacement
experienced by the rod tip is Δx = FL/AE, where the symbols carry their
usual meaning (the rod is just a linear spring). We compare our results
with the above result.

We can "improve" the theory by having a better estimate for the rod's
spring constant by assuming that it equilibrates under the new position,
with
Δx = F * (L + Δx) / (A * E)
which results in Δx = (F*L)/(A*E - F). Our rod reaches equilibrium with
respect to this position.

Note that if the damping is not high, the rod oscillates about the eventual
resting position (and this agrees with the theoretical predictions without
any damping: we should see the rod oscillating simple-harmonically in time).
23
24 isort:skip_file
25"""
26# FIXME without appending sys.path make it more generic
27import sys
28
29sys.path.append("../../") # isort:skip
30
31# from collections import defaultdict
32
33import numpy as np
34from matplotlib import pyplot as plt
35
36from elastica import *
37
38
class StretchingBeamSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
    """Simulation container for the axial stretching case.

    Combines the base system collection with constraint, forcing and callback
    support through mixins; it defines no behaviour of its own.
    """
41
42
stretch_sim = StretchingBeamSimulator()
final_time = 20.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False

# setting up test params
n_elem = 19
start = np.zeros((3,))
direction = np.array([1.0, 0.0, 0.0])  # rod lies along the x-axis
normal = np.array([0.0, 1.0, 0.0])
base_length = 1.0
base_radius = 0.025
base_area = np.pi * base_radius ** 2  # cross-section area A used by the theory
density = 1000
# Passed to straight_rod as ``nu``; presumably the internal damping constant,
# which this case wants high so the tip settles instead of oscillating.
nu = 2.0
youngs_modulus = 1e4
poisson_ratio = 0.5
# NOTE(review): the isotropic relation is G = E / (2 * (1 + poisson_ratio));
# these examples use E / (1 + poisson_ratio). The analytical tip displacement
# compared against below only uses E, not G.
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

stretchable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    youngs_modulus,
    shear_modulus=shear_modulus,
)

stretch_sim.append(stretchable_rod)
# Clamp the first node and the first director: this end is the "wall".
stretch_sim.constrain(stretchable_rod).using(
    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Axial pull of magnitude end_force_x. The two positional forces are presumably
# (force on the first node, force on the tip), so only the tip is loaded; the
# load is ramped up over 0.01 time units.
end_force_x = 1.0
end_force = np.array([end_force_x, 0.0, 0.0])
stretch_sim.add_forcing_to(stretchable_rod).using(
    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
)
89
# Add call backs
class AxialStretchingCallBack(CallBackBaseClass):
    """
    Call back function for the axial stretching case.

    Every ``step_skip`` steps, appends the current time and the x-coordinate
    of the rod's last node (the tip) to ``callback_params["time"]`` and
    ``callback_params["position"]`` respectively.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        CallBackBaseClass.__init__(self)
        self.every = step_skip
        # Must provide appendable "time" and "position" entries,
        # e.g. a defaultdict(list).
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return

        self.callback_params["time"].append(time)
        # Collect only x of the tip node (row 0 = x, column -1 = last node).
        self.callback_params["position"].append(
            system.position_collection[0, -1].copy()
        )
111
112
# Explicit import: the file-level import is commented out, so defaultdict
# would otherwise only come in through `from elastica import *`.
from collections import defaultdict

recorded_history = defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dl = base_length / n_elem
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: dx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # dx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(recorded_history["time"], recorded_history["position"], lw=2.0)
    ax.hlines(base_length + expected_tip_disp, 0.0, final_time, "k", "dashdot", lw=1.0)
    ax.hlines(
        base_length + expected_tip_disp_improved, 0.0, final_time, "k", "dashed", lw=2.0
    )
    ax.set_xlabel("time")
    ax.set_ylabel("tip x-position")
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)
Timoshenko
__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
Gazzola et al. R. Soc. 2018 section 3.4.3"""
3
4import numpy as np
5import sys
6
7# FIXME without appending sys.path make it more generic
8sys.path.append("../../")
9from elastica import *
10from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
11
12
class TimoshenkoBeamSimulator(BaseSystemCollection, Constraints, Forcing):
    """Simulation container for the Timoshenko beam case.

    Adds constraint and forcing support to the base system collection via
    mixins; no behaviour is defined here.
    """
15
16
timoshenko_sim = TimoshenkoBeamSimulator()
final_time = 5000

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False
ADD_UNSHEARABLE_ROD = True

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius ** 2
density = 5000
nu = 0.1
E = 1e6
# poisson_ratio = 99 gives shear_modulus = E / 100 = 1e4 with the
# E / (1 + poisson_ratio) relation used throughout these examples.
poisson_ratio = 99
shear_modulus = E / (poisson_ratio + 1.0)

shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    E,
    shear_modulus=shear_modulus,
)

timoshenko_sim.append(shearable_rod)
# Clamp the first node and first director.
timoshenko_sim.constrain(shearable_rod).using(
    OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Load along -x, transverse to the rod axis (z); ramped up over half the run.
end_force = np.array([-15.0, 0.0, 0.0])
timoshenko_sim.add_forcing_to(shearable_rod).using(
    EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
)


if ADD_UNSHEARABLE_ROD:
    # Start into the plane (offset by -1 along y from the shearable rod)
    unshearable_start = np.array([0.0, -1.0, 0.0])
    # Unshearable rod needs G -> inf, which is approximated with a -ve Poisson
    # ratio (-0.7). A separate name keeps `shear_modulus` describing the
    # shearable rod instead of silently overwriting it.
    unshearable_shear_modulus = E / (-0.7 + 1.0)
    unshearable_rod = CosseratRod.straight_rod(
        n_elem,
        unshearable_start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        nu,
        E,
        shear_modulus=unshearable_shear_modulus,
    )

    timoshenko_sim.append(unshearable_rod)
    timoshenko_sim.constrain(unshearable_rod).using(
        OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
    )
    timoshenko_sim.add_forcing_to(unshearable_rod).using(
        EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
    )

timoshenko_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dl = base_length / n_elem
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)
Butterfly
1# FIXME without appending sys.path make it more generic
2import sys
3
4sys.path.append("../")
5sys.path.append("../../")
6
7# from collections import defaultdict
8import numpy as np
9from matplotlib import pyplot as plt
10from matplotlib.colors import to_rgb
11
12
13from elastica import *
14from elastica.utils import MaxDimension
15
16
class ButterflySimulator(BaseSystemCollection, CallBacks):
    """Simulation container for the butterfly case.

    Only callback support is mixed into the base system collection; there are
    no constraints or external forces.
    """
19
20
butterfly_sim = ButterflySimulator()
final_time = 40.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = True
ADD_UNSHEARABLE_ROD = False

# setting up test params
# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
n_elem = 4  # Change based on requirements, but be careful
# Round an odd element count up to the next even number so the rod splits
# into two equal arms.
n_elem += n_elem % 2
half_n_elem = n_elem // 2

origin = np.zeros((3, 1))
angle_of_inclination = np.deg2rad(45.0)

# In-plane basis, stored as 3x1 column vectors so they broadcast over nodes.
horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)

# Out-of-plane normal
normal = np.array([0.0, 1.0, 0.0])

total_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius ** 2
density = 5000
nu = 0.0
youngs_modulus = 1e4
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

positions = np.empty((MaxDimension.value(), n_elem + 1))
dl = total_length / n_elem

# Node offsets 0, 1, ..., half_n_elem as a 1xK row, shared by both arms.
first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
cos_incl = np.cos(angle_of_inclination)
sin_incl = np.sin(angle_of_inclination)
rising_arm = cos_incl * horizontal_direction + sin_incl * vertical_direction
falling_arm = cos_incl * horizontal_direction - sin_incl * vertical_direction

# First arm climbs from the origin at +angle_of_inclination ...
positions[..., : half_n_elem + 1] = origin + dl * first_half * rising_arm
# ... second arm descends from the apex node, which both arms share.
apex = positions[..., half_n_elem : half_n_elem + 1]
positions[..., half_n_elem:] = apex + dl * first_half * falling_arm

# `position` overrides the straight layout implied by start/direction.
butterfly_rod = CosseratRod.straight_rod(
    n_elem,
    start=origin.reshape(3),
    direction=np.array([0.0, 0.0, 1.0]),
    normal=normal,
    base_length=total_length,
    base_radius=base_radius,
    density=density,
    nu=nu,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
    position=positions,
)

butterfly_sim.append(butterfly_rod)
86
# Add call backs
class VelocityCallBack(CallBackBaseClass):
    """
    Call back function for the butterfly case.

    Every ``step_skip`` steps, appends to ``callback_params`` the current time,
    a copy of all node positions, and the rod's translational ("te"),
    rotational ("re"), shear ("se") and bending ("be") energies.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        CallBackBaseClass.__init__(self)
        self.every = step_skip
        # Must provide appendable entries for every key written below,
        # e.g. a defaultdict(list).
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return

        params = self.callback_params
        params["time"].append(time)
        # Collect x (full copy: the simulation mutates position_collection in place)
        params["position"].append(system.position_collection.copy())
        # Collect energies as well, so their sum can be tracked over time
        params["te"].append(system.compute_translational_energy())
        params["re"].append(system.compute_rotational_energy())
        params["se"].append(system.compute_shear_energy())
        params["be"].append(system.compute_bending_energy())
111
112
# Explicit import: the file-level import is commented out, so defaultdict
# would otherwise only come in through `from elastica import *`.
from collections import defaultdict

recorded_history = defaultdict(list)
# initially record history
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
recorded_history["te"].append(butterfly_rod.compute_translational_energy())
recorded_history["re"].append(butterfly_rod.compute_rotational_energy())
recorded_history["se"].append(butterfly_rod.compute_shear_energy())
recorded_history["be"].append(butterfly_rod.compute_bending_energy())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = PositionVerlet()
# timestepper = PEFRL()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, butterfly_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Read the history by index instead of popping, so recorded_history is
    # left intact (and a single-snapshot history does not raise IndexError).
    position_history = recorded_history["position"]
    # Initial configuration: red dashed line
    first_position = position_history[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    later_positions = position_history[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        # Older snapshots are fainter: alpha rises from exp(-1) towards 1.
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # Final configuration: black dashed line
    last_position = position_history[-1]
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    translational_energy = np.asarray(recorded_history["te"])
    rotational_energy = np.asarray(recorded_history["re"])
    bending_energy = np.asarray(recorded_history["be"])
    shear_energy = np.asarray(recorded_history["se"])
    total_energy = (
        translational_energy + rotational_energy + bending_energy + shear_energy
    )

    energy_ax.plot(
        times, translational_energy, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations"
    )
    energy_ax.plot(
        times, rotational_energy, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation"
    )
    energy_ax.plot(times, bending_energy, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(
        times, shear_energy, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear"
    )
    energy_ax.plot(times, total_energy, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)
Helical Buckling
__doc__ = """Helical buckling validation case, for detailed explanation refer to
Gazzola et al. R. Soc. 2018 section 3.4.1"""
3
4import numpy as np
5import sys
6
7# FIXME without appending sys.path make it more generic
8sys.path.append("../../")
9from elastica import *
10from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
11 plot_helicalbuckling,
12)
13
14
class HelicalBucklingSimulator(BaseSystemCollection, Constraints, Forcing):
    """Simulation container for the helical buckling case.

    Adds constraint and forcing support to the base system collection via
    mixins; no behaviour is defined here.
    """
17
18
helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius ** 2
density = 1.0 / (base_area)  # density * area = 1, i.e. unit mass per unit length
nu = 0.01
E = 1e6
slack = 3
number_of_rotations = 27
# poisson_ratio = 9 gives shear_modulus = E / 10 = 1e5 with the
# E / (1 + poisson_ratio) relation used throughout these examples.
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear stiffness: one 3x3 matrix per element, stacked on the last axis.
shear_matrix = np.repeat(
    shear_modulus * np.identity(3)[:, :, np.newaxis], n_elem, axis=2
)
# Prescribed diagonal bend/twist stiffness (presumably two bending entries and
# one twist entry), one 3x3 matrix for each of the n_elem - 1 interior locations.
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
# Both ends are constrained; HelicalBucklingBC applies the slack and the
# number_of_rotations twist over twisting_time.
helicalbuckling_sim.constrain(shearable_rod).using(
    HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)

helicalbuckling_sim.finalize()
timestepper = PositionVerlet()
# Tiny y-velocity kick at the middle node, presumably to break the symmetry
# so the buckled shape can develop.
shearable_rod.velocity_collection[..., n_elem // 2] += np.array([0, 1e-6, 0.0])
# timestepper = PEFRL()

final_time = 10500.0
dl = base_length / n_elem
dt = 1e-3 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
integrate(timestepper, helicalbuckling_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)

if SAVE_RESULTS:
    import pickle

    filename = "HelicalBuckling_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)
Continuum Snake
__doc__ = """Snake friction case from X. Zhang et al. Nat. Comm. 2021"""
2
3import sys
4import os
5import numpy as np
6
7sys.path.append("../../")
8from elastica import *
9
10from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
11 plot_snake_velocity,
12 plot_video,
13 compute_projected_velocity,
14 plot_curvature,
15)
16
17
class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
    """Simulation container for the continuum snake case.

    Combines constraint, forcing and callback support on top of the base
    system collection through mixins; it adds no behaviour of its own.
    """
20
21
def run_snake(
    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
):
    """Simulate the continuum snake crawling on an anisotropic frictional plane.

    Parameters
    ----------
    b_coeff : array-like
        Muscle-torque spline coefficients followed by the wave length:
        ``b_coeff[:-1]`` are passed to ``MuscleTorques`` as ``b_coeff`` and
        ``b_coeff[-1]`` sets the wave number ``2 * pi / wave_length``.
    PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS : bool
        Toggle the velocity/curvature plots, saving those plots, rendering the
        video, and pickling the recorded diagnostics.

    Returns
    -------
    tuple
        ``(avg_forward, avg_lateral, pp_list)``: the average forward and
        lateral velocities from ``compute_projected_velocity`` and the dict of
        recorded diagnostics.
    """
    # Explicit import: this script otherwise only gets defaultdict through
    # `from elastica import *`.
    from collections import defaultdict

    # Initialize the simulation class
    snake_sim = SnakeSimulator()

    # Simulation parameters
    period = 2
    final_time = (11.0 + 0.01) * period
    time_step = 8e-6
    total_steps = int(final_time / time_step)
    rendering_fps = 60
    # Record diagnostics roughly once per rendered video frame.
    step_skip = int(1.0 / (rendering_fps * time_step))

    # setting up test params
    n_elem = 50
    start = np.zeros((3,))
    direction = np.array([0.0, 0.0, 1.0])
    normal = np.array([0.0, 1.0, 0.0])
    base_length = 0.35
    base_radius = base_length * 0.011
    density = 1000
    nu = 1e-4
    E = 1e6
    poisson_ratio = 0.5
    shear_modulus = E / (poisson_ratio + 1.0)

    shearable_rod = CosseratRod.straight_rod(
        n_elem,
        start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        nu,
        E,
        shear_modulus=shear_modulus,
    )

    snake_sim.append(shearable_rod)

    # Add gravitational forces (along -y, i.e. against `normal`)
    gravitational_acc = -9.80665
    snake_sim.add_forcing_to(shearable_rod).using(
        GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
    )

    # Add muscle torques: last entry of b_coeff is the wave length, the rest
    # are the spline coefficients.
    wave_length = b_coeff[-1]
    snake_sim.add_forcing_to(shearable_rod).using(
        MuscleTorques,
        base_length=base_length,
        b_coeff=b_coeff[:-1],
        period=period,
        wave_number=2.0 * np.pi / (wave_length),
        phase_shift=0.0,
        rest_lengths=shearable_rod.rest_lengths,
        ramp_up_time=period,
        direction=normal,
        with_spline=True,
    )

    # Add friction forces; the ground plane sits one radius below the rod axis.
    origin_plane = np.array([0.0, -base_radius, 0.0])
    normal_plane = normal
    slip_velocity_tol = 1e-8
    froude = 0.1
    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
    kinetic_mu_array = np.array(
        [mu, 1.5 * mu, 2.0 * mu]
    )  # [forward, backward, sideways]
    static_mu_array = np.zeros(kinetic_mu_array.shape)
    snake_sim.add_forcing_to(shearable_rod).using(
        AnisotropicFrictionalPlane,
        k=1.0,
        nu=1e-6,
        plane_origin=origin_plane,
        plane_normal=normal_plane,
        slip_velocity_tol=slip_velocity_tol,
        static_mu_array=static_mu_array,
        kinetic_mu_array=kinetic_mu_array,
    )

    # Add call backs
    class ContinuumSnakeCallBack(CallBackBaseClass):
        """
        Call back function for continuum snake
        """

        def __init__(self, step_skip: int, callback_params: dict):
            CallBackBaseClass.__init__(self)
            self.every = step_skip
            self.callback_params = callback_params

        def make_callback(self, system, time, current_step: int):

            if current_step % self.every == 0:

                self.callback_params["time"].append(time)
                self.callback_params["step"].append(current_step)
                self.callback_params["position"].append(
                    system.position_collection.copy()
                )
                self.callback_params["velocity"].append(
                    system.velocity_collection.copy()
                )
                self.callback_params["avg_velocity"].append(
                    system.compute_velocity_center_of_mass()
                )

                self.callback_params["center_of_mass"].append(
                    system.compute_position_center_of_mass()
                )
                self.callback_params["curvature"].append(system.kappa.copy())

                return

    pp_list = defaultdict(list)
    snake_sim.collect_diagnostics(shearable_rod).using(
        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
    )

    snake_sim.finalize()

    timestepper = PositionVerlet()
    integrate(timestepper, snake_sim, final_time, total_steps)

    if PLOT_FIGURE:
        filename_plot = "continuum_snake_velocity.png"
        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)

        if SAVE_VIDEO:
            filename_video = "continuum_snake.mp4"
            plot_video(
                pp_list,
                video_name=filename_video,
                fps=rendering_fps,
                xlim=(0, 4),
                ylim=(-1, 1),
            )

    if SAVE_RESULTS:
        import pickle

        filename = "continuum_snake.dat"
        # The context manager closes the file even if pickling raises.
        with open(filename, "wb") as file:
            pickle.dump(pp_list, file)

    # Compute the average forward velocity. These will be used for optimization.
    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)

    return avg_forward, avg_lateral, pp_list
177
178
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: negated average forward velocity (CMA minimizes)."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin the first input is the
        # function to be optimized, the second is the initial guess for the
        # coefficients (6 spline coefficients followed by the wave length, as
        # run_snake expects) and the third is the initial standard deviation.
        # NOTE(review): the initial wave length of 0 makes run_snake's wave
        # number 2*pi/0 on the first evaluation — confirm this is intended.
        # cma.fmin returns a result tuple whose first entry is the best solution;
        # only that vector can be written with np.savetxt and read back by the
        # np.genfromtxt call in the branch below.
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Use previously optimized coefficients if present (they already include
        # the wave length); otherwise use hand-picked defaults plus wave length 1.0.
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)