# Source code for elastica.transformations

# Module docstring, set by assigning the module-level __doc__ attribute directly.
__doc__ = """ Rotation interface functions"""

import numpy as np

from elastica._rotations import (
    _inv_skew_symmetrize,
    _skew_symmetrize,
    _rotate,
)

from .utils import MaxDimension, isqrt


# TODO: The interface is functionally complete but could be nicer; evolve it eventually.


def format_vector_shape(vector_collection):
    """
    Format a vector collection into a 2D ``(dim, n)`` array.

    ``dim`` is ``MaxDimension.value()``; the error message below implies this
    is 3 (pure 3D vectors only, for now).

    Parameters
    ----------
    vector_collection: numpy.ndarray
        Can be 1D ``(dim,)`` or 2D, laid out either as ``(dim, n)`` or as
        ``(n, dim)`` (row vectors, e.g. ``(1, 3)`` or ``(2, 3)``).

    Returns
    -------
    output: numpy.ndarray
        Always 2D, with first dimension ``MaxDimension.value()``. May be a
        view (expanded dims / transpose) of the input rather than a copy.

    Raises
    ------
    RuntimeError
        If the input is 0D or has more than two dimensions.
    AssertionError
        If, after formatting, the first dimension is not
        ``MaxDimension.value()``.
    """
    n_dim = vector_collection.ndim
    dim = MaxDimension.value()

    if n_dim == 0:
        # A 0D array has no shape[0] to validate; reject it explicitly.
        raise RuntimeError("Vector collection dimensions <1 are not supported")
    if n_dim > 2:
        raise RuntimeError("Vector collection dimensions >2 are not supported")

    if n_dim == 1:
        # Shape is (dim,): promote to a single-column (dim, 1) array.
        vector_collection = np.expand_dims(vector_collection, axis=1)
    else:
        # First possibility: shape is (blocksize, dim) with a large blocksize.
        # Soft fix: transpose so that the first dimension is the smaller one.
        if vector_collection.shape[0] > max(dim, vector_collection.shape[1]):
            vector_collection = vector_collection.T

        # Second possibility: shape is (blocksize, dim) with blocksize < dim,
        # e.g. row vectors (1, 3) or (2, 3).
        if (
            vector_collection.shape[0] < dim
            and vector_collection.shape[1] == dim
        ):
            vector_collection = vector_collection.T

    # Only pure 3D vectors are supported for now. Raise explicitly instead of
    # using `assert`, so the check is not stripped under `python -O`.
    if vector_collection.shape[0] != dim:
        raise AssertionError("Need first dimension = 3")

    return vector_collection
def format_matrix_shape(matrix_collection):
    """
    Format a matrix collection into a 3D ``(dim, dim, n)`` array.

    ``dim`` is ``MaxDimension.value()``; the final error message implies this
    is 3. For dim = 3 the accepted layouts are ``(9,)``, ``(3, 3)``,
    ``(9, n)``, ``(n, 9)`` (``n != 9``), ``(3, 3, n)`` and ``(n, 3, 3)``
    (``n > 3``).

    Parameters
    ----------
    matrix_collection: numpy.ndarray
        Can be 1D, 2D, 3D.

    Returns
    -------
    output: numpy.ndarray
        3D array of shape ``(dim, dim, n)``. Flattened matrices are
        unflattened row-major (``reshape``'s default C order).

    Raises
    ------
    RuntimeError
        If the input is 0D or has more than three dimensions.
    AssertionError
        If a flattened length is not a perfect square, the two matrix axes of
        a 3D input differ, or the resulting matrix dimension is not
        ``MaxDimension.value()``.
    """
    n_dim = matrix_collection.ndim
    max_dim = MaxDimension.value()

    # check first two dimensions are same and matrix is square
    # other possibility is one dimension is dim**2 and other is blocksize,
    # we need to convert the matrix in that case.
    def assert_proper_square(num1):
        """Return isqrt(num1); raise AssertionError if num1 is not a perfect square."""
        sqrt_num = isqrt(num1)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if sqrt_num ** 2 != num1:
            raise AssertionError("Matrix dimension passed is not a perfect square")
        return sqrt_num

    if n_dim == 0:
        # No branch below can determine `dim` for a 0D array; reject it explicitly.
        raise RuntimeError("Matrix dimensions <1 are not supported")
    elif n_dim == 1:
        # Shape is (dim**2,): check dim**2 is a perfect square, then
        # reshape to (dim, dim, 1).
        dim = assert_proper_square(matrix_collection.shape[0])
        matrix_collection = np.atleast_3d(matrix_collection).reshape(dim, dim, 1)
    elif n_dim == 2:
        if matrix_collection.shape[0] == matrix_collection.shape[1]:
            # Already a square matrix, e.g. (3, 3)
            dim = matrix_collection.shape[0]
        else:
            # First possibility: shape is (blocksize, dim**2) with a large
            # blocksize. Soft fix: transpose so the first dimension is least.
            if matrix_collection.shape[0] > max(
                max_dim ** 2, matrix_collection.shape[1]
            ):
                matrix_collection = matrix_collection.T
            # Second possibility: shape is (blocksize, dim**2) with
            # blocksize < dim**2, e.g. (1, 9) or (2, 9) -- one flattened
            # matrix per row, like the row-vector case of format_vector_shape.
            # Without this transpose such inputs always failed the checks below.
            elif (
                matrix_collection.shape[0] != max_dim ** 2
                and matrix_collection.shape[1] == max_dim ** 2
            ):
                matrix_collection = matrix_collection.T
            # Check dim**2 is a perfect square
            dim = assert_proper_square(matrix_collection.shape[0])

        # Expand to three dimensions
        # inp : (dim, dim) or (dim**2, bs)
        # op  : (dim, dim, bs)
        matrix_collection = matrix_collection.reshape(dim, dim, -1)
    elif n_dim == 3:
        # First possibility: shape is (blocksize, dim, dim), detected when the
        # leading axis exceeds both trailing axes (and max_dim).
        # NOTE(review): `.T` reverses *all* axes, so output[:, :, k] equals
        # input[k].T, i.e. every matrix is transposed as well as moved. A pure
        # axis move would be np.moveaxis(matrix_collection, 0, -1). Kept as-is
        # to preserve existing behavior -- confirm intent.
        if matrix_collection.shape[0] > max(
            max_dim, matrix_collection.shape[1]
        ) and matrix_collection.shape[0] > max(max_dim, matrix_collection.shape[2]):
            matrix_collection = matrix_collection.T
        # Given a (dim, dim, bs) array, check the matrix axes are equal
        if matrix_collection.shape[0] != matrix_collection.shape[1]:
            raise AssertionError("Matrix shapes along 1 and 2 are not equal")
        dim = matrix_collection.shape[0]
    else:
        raise RuntimeError("Matrix dimensions >3 are not supported")

    if dim != max_dim:
        raise AssertionError(
            "Need matrix dimension = 3 for example (9,), (3,3), (3,3,1), (9,n), (3,3,n)"
        )

    return matrix_collection
def skew_symmetrize(vector):
    """
    Safe wrapper around ``_skew_symmetrize``.

    The input is first normalized by format_vector_shape, and the normalized
    array is passed to the unchecked kernel from ``elastica._rotations``.

    Parameters
    ----------
    vector: numpy.ndarray
        Any 1D or 2D layout accepted by format_vector_shape.

    Returns
    -------
    Whatever ``_skew_symmetrize`` returns for the formatted collection.
    """
    return _skew_symmetrize(format_vector_shape(vector))
def inv_skew_symmetrize(matrix_collection):
    """
    Checked wrapper around ``_inv_skew_symmetrize``.

    The input is normalized with format_matrix_shape into a ``(dim, dim, n)``
    array. It is then verified to be skew-symmetric before it is handed to
    the unchecked kernel: each matrix must equal the negative of its
    transpose, within ``np.allclose`` default tolerances.

    Parameters
    ----------
    matrix_collection: numpy.ndarray
        Any layout accepted by format_matrix_shape.

    Returns
    -------
    The result of ``_inv_skew_symmetrize`` on the formatted collection.

    Raises
    ------
    ValueError
        If the formatted collection is not skew-symmetric.
    """
    formatted = format_matrix_shape(matrix_collection)

    # Swap the two matrix axes of every slice. transpose() returns a view, so
    # no copy is made here; negating it inside allclose does allocate.
    swapped = formatted.transpose(1, 0, 2)

    if not np.allclose(formatted, -swapped):
        raise ValueError("matrix_collection passed is not skew-symmetric")

    return _inv_skew_symmetrize(formatted)
def rotate(matrix, scale, axis):
    """
    Rotate one or many frames about one or many axes.

    All frames can be rotated about a single shared axis, or each frame can be
    rotated about its own user-supplied axis. ``scale`` sets how far each
    frame rotates about its axis.

    Parameters
    ----------
    matrix: numpy.ndarray
        Frame(s). Minimum shape ``dim**2 x 1``; supports ``3 x 3 x n``
        (any layout accepted by format_matrix_shape).
    scale: float or numpy.ndarray
        Minimum a float; also supports 1D vectors of length ``n``.
    axis: numpy.ndarray
        Rotation axis/axes. Minimum ``3 x 1`` or ``1 x 3``; supports
        ``3 x n`` and ``n x 3`` (any layout accepted by format_vector_shape).

    Returns
    -------
    The result of ``_rotate`` on the formatted frames and axes.
    """
    # Arguments are evaluated left to right, so the matrix is formatted (and
    # validated) before the axis, as in a sequential implementation.
    return _rotate(format_matrix_shape(matrix), scale, format_vector_shape(axis))