Example Cases
Example cases demonstrate physical examples that have a known analytical solution or a well-studied behavior. Each case follows the recommended workflow shown here. Feel free to use them as initial templates for your own case studies.
Axial Stretching
""" Axial stretching test-case

Assume we have a rod lying aligned in the x-direction, with high internal
damping.

We fix one end (say, the left end) of the rod to a wall. On the right
end we apply a force directed axially, pulling the rod's tip. Linear
theory (assuming small displacements) predicts that the net displacement
experienced by the rod tip is Δx = FL/AE, where the symbols carry their
usual meaning (the rod is just a linear spring). We compare our results
with the above result.

We can "improve" the theory by having a better estimate for the rod's
spring constant by assuming that it equilibrates under the new position,
with
Δx = F * (L + Δx) / (A * E)
which results in Δx = (F * L) / (A * E - F). Our rod reaches equilibrium
with respect to this position.

Note that if the damping is not high, the rod oscillates about the eventual
resting position (and this agrees with the theoretical predictions without
any damping: we should see the rod oscillating simple-harmonically in time).

isort:skip_file
"""
import numpy as np
from matplotlib import pyplot as plt

import elastica as ea


class StretchingBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
):
    """System collection for the axial-stretching case.

    Combines the base collection with constraint, forcing, damping and
    callback support; it adds no behavior of its own.
    """


stretch_sim = StretchingBeamSimulator()
final_time = 200.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False

# setting up test params
n_elem = 19
start = np.zeros((3,))
direction = np.array([1.0, 0.0, 0.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 1.0
base_radius = 0.025
base_area = np.pi * base_radius ** 2
density = 1000
youngs_modulus = 1e4
# This example derives the shear modulus as E / (1 + nu); with nu = 0.5 this
# gives G = E / 1.5.
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

stretchable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
)

stretch_sim.append(stretchable_rod)
# Clamp the left end: position index 0 and director index 0 are held fixed.
stretch_sim.constrain(stretchable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Axial pull of magnitude end_force_x on the tip; the first force argument
# (a zero vector) belongs to the fixed end. The load is ramped up over
# ramp_up_time.
end_force_x = 1.0
end_force = np.array([end_force_x, 0.0, 0.0])
stretch_sim.add_forcing_to(stretchable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
)

# add damping
dl = base_length / n_elem
# dt is also the integration time step used below, so the damper and the
# integrator agree.
dt = 0.1 * dl
damping_constant = 0.1
stretch_sim.dampen(stretchable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=damping_constant,
    time_step=dt,
)

# Add call backs
class AxialStretchingCallBack(ea.CallBackBaseClass):
    """
    Every ``step_skip`` steps, record into ``callback_params``:

    - ``"time"``: the current simulation time,
    - ``"position"``: the x-coordinate of the last node (the rod tip),
    - ``"velocity_norms"``: the norm of the whole velocity array.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        super().__init__()
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return

        self.callback_params["time"].append(time)
        # Collect only the x-coordinate of the tip (last node)
        self.callback_params["position"].append(
            system.position_collection[0, -1].copy()
        )
        # Norm over all nodes and components; np.linalg.norm does not mutate
        # its input, so no defensive copy is needed.
        self.callback_params["velocity_norms"].append(
            np.linalg.norm(system.velocity_collection)
        )


recorded_history = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: dx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # dx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="simulated tip position",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="improved theory",
    )
    ax.set_xlabel("time")
    ax.set_ylabel("tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    # Two rows (time, velocity norm) transposed into two CSV columns.
    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )
Timoshenko
__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
Gazzola et. al. R. Soc. 2018 section 3.4.3 """

import numpy as np
import elastica as ea
from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko


class TimoshenkoBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
):
    """Simulation container for the Timoshenko beam case.

    Mixes constraint, forcing, callback and damping support into
    ``ea.BaseSystemCollection``; defines nothing else.
    """


timoshenko_sim = TimoshenkoBeamSimulator()
final_time = 5000.0

# Run-time switches
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False
ADD_UNSHEARABLE_ROD = False

# Discretisation and geometry of the cantilever (aligned with +z)
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2

# Element length and time step; dt is shared by the damper and the integrator
dl = base_length / n_elem
dt = 0.07 * dl

# Material properties
density = 5000
nu = 0.1 / 7 / density / base_area  # used as the damping constant below
E = 1e6
# With E / (1 + nu) and nu = 99 the shear modulus comes out as 1e4
poisson_ratio = 99
shear_modulus = E / (poisson_ratio + 1.0)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem, start, direction, normal, base_length, base_radius, density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
timoshenko_sim.append(shearable_rod)

# Linear damping applied with the simulation time step
timoshenko_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Cantilever: clamp position index 0 and director index 0
timoshenko_sim.constrain(shearable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Transverse tip load along -x, ramped up over the first half of the run
end_force = np.array([-15.0, 0.0, 0.0])
timoshenko_sim.add_forcing_to(shearable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
)


if ADD_UNSHEARABLE_ROD:
    # Start into the plane: offset by -1 along y from the shearable rod
    unshearable_start = np.array([0.0, -1.0, 0.0])
    # Unshearable rod needs G -> inf. With the E / (1 + nu) relation used in
    # this example, a negative Poisson ratio (-0.7) makes G much larger.
    # A dedicated name keeps the module-level ``shear_modulus`` of the
    # shearable rod from being overwritten.
    unshearable_shear_modulus = E / (-0.7 + 1.0)
    unshearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        unshearable_start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        shear_modulus=unshearable_shear_modulus,
    )

    timoshenko_sim.append(unshearable_rod)

    # add damping (same constant and time step as the shearable rod)
    timoshenko_sim.dampen(unshearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=nu,
        time_step=dt,
    )
    timoshenko_sim.constrain(unshearable_rod).using(
        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
    )
    timoshenko_sim.add_forcing_to(unshearable_rod).using(
        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
    )

# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Tracks the velocity norms of the rod.

    Every ``step_skip`` steps, appends the current time to
    ``callback_params["time"]`` and the norm of the whole velocity array to
    ``callback_params["velocity_norms"]``.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        super().__init__()
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return

        self.callback_params["time"].append(time)
        # Norm over all nodes and components of the velocity
        self.callback_params["velocity_norms"].append(
            np.linalg.norm(system.velocity_collection)
        )


recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack, step_skip=500, callback_params=recorded_history
)

timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    # Two rows (time, velocity norm) transposed into two CSV columns.
    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )
Butterfly
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import to_rgb


import elastica as ea
from elastica.utils import MaxDimension


class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
    """System collection for the butterfly case.

    Only callback support is mixed in: no constraints, forcing or damping.
    """


butterfly_sim = ButterflySimulator()
final_time = 40.0

# Run-time switches
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = True
ADD_UNSHEARABLE_ROD = False

# Discretisation
# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
n_elem = 4  # Change based on requirements, but be careful
n_elem += n_elem % 2  # round an odd count up so the rod splits into two equal arms
half_n_elem = n_elem // 2

origin = np.zeros((3, 1))
angle_of_inclination = np.deg2rad(45.0)

# In-plane unit vectors of the initial "V" shape, stored as column vectors
horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)

# Out-of-plane normal
normal = np.array([0.0, 1.0, 0.0])

# Geometry and material
total_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2
density = 5000
youngs_modulus = 1e4
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

# Node positions, one column per node (n_elem + 1 nodes)
positions = np.empty((MaxDimension.value(), n_elem + 1))
dl = total_length / n_elem

cos_incl = np.cos(angle_of_inclination)
sin_incl = np.sin(angle_of_inclination)
# Node offsets 0, 1, ..., half_n_elem along each arm, as a row vector
first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)

# The first arm rises from the origin at +angle_of_inclination ...
rising_arm = cos_incl * horizontal_direction + sin_incl * vertical_direction
positions[..., : half_n_elem + 1] = origin + dl * first_half * rising_arm

# ... and the second arm starts from the apex node and descends symmetrically.
apex = positions[..., half_n_elem : half_n_elem + 1]
falling_arm = cos_incl * horizontal_direction - sin_incl * vertical_direction
positions[..., half_n_elem:] = apex + dl * first_half * falling_arm

# The explicit ``position`` array supplies the bent initial shape; the other
# straight-rod arguments are still passed to the factory.
butterfly_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start=origin.reshape(3),
    direction=np.array([0.0, 0.0, 1.0]),
    normal=normal,
    base_length=total_length,
    base_radius=base_radius,
    density=density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
    position=positions,
)

butterfly_sim.append(butterfly_rod)

# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Call back function for the butterfly rod.

    Every ``step_skip`` steps, records the time, a copy of all node positions,
    and the translational ("te"), rotational ("re"), shear ("se") and
    bending ("be") energies into ``callback_params``.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        super().__init__()
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every != 0:
            return

        self.callback_params["time"].append(time)
        # Snapshot of all node positions
        self.callback_params["position"].append(system.position_collection.copy())
        # Collect energies as well
        self.callback_params["te"].append(system.compute_translational_energy())
        self.callback_params["re"].append(system.compute_rotational_energy())
        self.callback_params["se"].append(system.compute_shear_energy())
        self.callback_params["be"].append(system.compute_bending_energy())


# Seed the history with the initial state before integration starts
recorded_history = ea.defaultdict(list)
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
for key, energy_fn in (
    ("te", butterfly_rod.compute_translational_energy),
    ("re", butterfly_rod.compute_rotational_energy),
    ("se", butterfly_rod.compute_shear_energy),
    ("be", butterfly_rod.compute_bending_energy),
):
    recorded_history[key].append(energy_fn())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Plot the histories (x against z for every recorded snapshot)
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Index into the history rather than popping from it, so that
    # recorded_history["position"] is left intact.
    all_positions = recorded_history["position"]
    # record first position
    first_position = all_positions[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    later_positions = all_positions[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        # Later snapshots are drawn more opaque (alpha rises from exp(-1)).
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # The final position is also highlighted. Skip it when nothing was
    # recorded after the initial state (list.pop() would raise IndexError).
    if later_positions:
        last_position = later_positions[-1]
        ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.set_xlabel("time")
    energy_ax.set_ylabel("energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)
Helical Buckling
__doc__ = """Helical buckling validation case, for detailed explanation refer to
Gazzola et. al. R. Soc. 2018 section 3.4.1 """

import numpy as np
import elastica as ea
from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
    plot_helicalbuckling,
)


class HelicalBucklingSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
):
    """System collection for the helical buckling case.

    Adds constraint, damping and forcing support to the base collection.
    """


helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius ** 2
# Chosen so that density * base_area (mass per unit length) equals 1
density = 1.0 / (base_area)
nu = 0.01 / density / base_area  # used as the damping constant
E = 1e6
slack = 3
number_of_rotations = 27
# For a shear modulus of 1e5 (= E / (1 + nu) with E = 1e6), nu is 9!
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear stiffness, stacked along a trailing axis of length n_elem
shear_matrix = np.repeat(
    shear_modulus * np.identity(3)[:, :, np.newaxis], n_elem, axis=2
)
# Diagonal bend stiffness, stacked along a trailing axis of length n_elem - 1
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

# Replace the stiffness matrices set by the factory with the ones built above
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix


helicalbuckling_sim.append(shearable_rod)
# add damping
dl = base_length / n_elem
# dt is also the integration time step used below
dt = 1e-3 * dl
helicalbuckling_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Both end nodes/directors are constrained; slack, number_of_rotations and
# twisting_time parameterise the helical buckling boundary condition.
helicalbuckling_sim.constrain(shearable_rod).using(
    ea.HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)

helicalbuckling_sim.finalize()
timestepper = ea.PositionVerlet()
# Tiny y-velocity kick at the middle node, presumably to break the symmetry
# of the straight configuration.
shearable_rod.velocity_collection[..., n_elem // 2] += np.array([0, 1e-6, 0.0])
# timestepper = PEFRL()

final_time = 10500.0
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)

if SAVE_RESULTS:
    import pickle

    filename = "HelicalBuckling_data.dat"
    # The context manager closes the file even if pickling raises.
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)
Continuum Snake
__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""

import os
import numpy as np
import elastica as ea

from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
    plot_snake_velocity,
    plot_video,
    compute_projected_velocity,
    plot_curvature,
)


class SnakeSimulator(
    ea.BaseSystemCollection,
    ea.Constraints,
    ea.Forcing,
    ea.Damping,
    ea.CallBacks,
    ea.Contact,
):
    """System collection for the continuum snake.

    Supports constraints, forcing, damping, callbacks and contact.
    """


def run_snake(
    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
):
    """Simulate the continuum snake crawling on an anisotropic-friction plane.

    Parameters
    ----------
    b_coeff : array-like
        Muscle-torque spline coefficients followed by the wave length:
        ``b_coeff[:-1]`` are passed as the spline coefficients and
        ``b_coeff[-1]`` is the wave length.
    PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS : bool
        Post-processing switches.

    Returns
    -------
    tuple
        ``(avg_forward, avg_lateral, pp_list)``: the average forward and
        lateral velocities from ``compute_projected_velocity`` and the
        recorded callback history.
    """
    # Initialize the simulation class
    snake_sim = SnakeSimulator()

    # Simulation parameters
    period = 2
    final_time = (11.0 + 0.01) * period

    # setting up test params
    n_elem = 50
    start = np.zeros((3,))
    direction = np.array([0.0, 0.0, 1.0])
    normal = np.array([0.0, 1.0, 0.0])
    base_length = 0.35
    base_radius = base_length * 0.011
    density = 1000
    E = 1e6
    poisson_ratio = 0.5
    shear_modulus = E / (poisson_ratio + 1.0)

    shearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        shear_modulus=shear_modulus,
    )

    snake_sim.append(shearable_rod)

    # Add gravitational forces (along -y, i.e. against the ground-plane normal)
    gravitational_acc = -9.80665
    snake_sim.add_forcing_to(shearable_rod).using(
        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
    )

    # Add muscle torques; the last entry of b_coeff is the wave length
    wave_length = b_coeff[-1]
    snake_sim.add_forcing_to(shearable_rod).using(
        ea.MuscleTorques,
        base_length=base_length,
        b_coeff=b_coeff[:-1],
        period=period,
        wave_number=2.0 * np.pi / (wave_length),
        phase_shift=0.0,
        rest_lengths=shearable_rod.rest_lengths,
        ramp_up_time=period,
        direction=normal,
        with_spline=True,
    )

    # Add friction forces; the ground plane sits one radius below the rod axis
    ground_plane = ea.Plane(
        plane_origin=np.array([0.0, -base_radius, 0.0]), plane_normal=normal
    )
    snake_sim.append(ground_plane)
    slip_velocity_tol = 1e-8
    froude = 0.1
    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
    kinetic_mu_array = np.array(
        [mu, 1.5 * mu, 2.0 * mu]
    )  # [forward, backward, sideways]
    static_mu_array = np.zeros(kinetic_mu_array.shape)
    snake_sim.detect_contact_between(shearable_rod, ground_plane).using(
        ea.RodPlaneContactWithAnisotropicFriction,
        k=1.0,
        nu=1e-6,
        slip_velocity_tol=slip_velocity_tol,
        static_mu_array=static_mu_array,
        kinetic_mu_array=kinetic_mu_array,
    )

    # add damping
    damping_constant = 2e-3
    time_step = 1e-4
    snake_sim.dampen(shearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=damping_constant,
        time_step=time_step,
    )

    total_steps = int(final_time / time_step)
    rendering_fps = 60
    # Record approximately once per rendered video frame
    step_skip = int(1.0 / (rendering_fps * time_step))

    # Add call backs
    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
        """
        Call back function for continuum snake: every ``step_skip`` steps it
        records time, step, positions, velocities, centre-of-mass velocity and
        position, and curvature (``kappa``).
        """

        def __init__(self, step_skip: int, callback_params: dict):
            super().__init__()
            self.every = step_skip
            self.callback_params = callback_params

        def make_callback(self, system, time, current_step: int):
            if current_step % self.every != 0:
                return

            self.callback_params["time"].append(time)
            self.callback_params["step"].append(current_step)
            self.callback_params["position"].append(
                system.position_collection.copy()
            )
            self.callback_params["velocity"].append(
                system.velocity_collection.copy()
            )
            self.callback_params["avg_velocity"].append(
                system.compute_velocity_center_of_mass()
            )

            self.callback_params["center_of_mass"].append(
                system.compute_position_center_of_mass()
            )
            self.callback_params["curvature"].append(system.kappa.copy())

    pp_list = ea.defaultdict(list)
    snake_sim.collect_diagnostics(shearable_rod).using(
        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
    )

    snake_sim.finalize()

    timestepper = ea.PositionVerlet()
    ea.integrate(timestepper, snake_sim, final_time, total_steps)

    if PLOT_FIGURE:
        filename_plot = "continuum_snake_velocity.png"
        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)

    if SAVE_VIDEO:
        filename_video = "continuum_snake.mp4"
        plot_video(
            pp_list,
            video_name=filename_video,
            fps=rendering_fps,
            xlim=(0, 4),
            ylim=(-1, 1),
        )

    if SAVE_RESULTS:
        import pickle

        filename = "continuum_snake.dat"
        # The context manager closes the file even if pickling raises.
        with open(filename, "wb") as file:
            pickle.dump(pp_list, file)

    # Compute the average forward velocity. These will be used for optimization.
    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)

    return avg_forward, avg_lateral, pp_list


if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            # cma minimises, so negate to maximise the forward velocity
            return -avg_forward

        # Optimize snake for forward velocity. The first input of cma.fmin is
        # the function to be optimized, the second is the initial guess for
        # the coefficients (6 spline coefficients + wave length) and the third
        # is the initial standard deviation.
        # cma.fmin returns a result sequence whose first entry is the best
        # solution found; the remaining entries are diagnostics, so keep [0].
        optimized_spline_coefficients = cma.fmin(optimize_snake, 7 * [0], 0.5)[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            assert filename_data != "", "provide a file name for coefficients"
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)