Source code for pydy.codegen.matrix_generator

#!/usr/bin/env python

import itertools

import sympy as sm
import sympy.physics.mechanics as me
from sympy.printing.codeprinter import CodePrinter

from ..utils import wrap_and_indent, find_dynamicsymbols


[docs]class MatrixGenerator(object): """This abstract base class generates source files that simultaneously numerically evaluate any number of SymPy matrices. """ _idx_start = 0 _idx_delim = "[]" _base_printer = CodePrinter # TODO : I opened an issue with SymPy to deal with the type declaration, # https://github.com/sympy/sympy/issues/10446. _type_declar = 'double ' # prepended to variable introductions _line_contin = None _comment_char = '#'
[docs] def __init__(self, arguments, matrices, cse=True): """ Parameters ========== arguments : sequence of sequences of SymPy Symbol or Function Each of the sequences will be converted to input arrays in the generated function. All of the symbols/functions contained in ``matrices`` need to be in the sequences, but the sequences can also contain extra symbols/functions that are not contained in the matrices. matrices : sequence of SymPy.Matrix A sequence of the matrices that should be evaluated in the function. The expressions should contain only sympy.Symbol or sympy.Function that are functions of me.dynamicsymbols._t. cse : boolean Find and replace common sub-expressions in ``matrices`` if True. """ required_args = set() for matrix in matrices: # TODO : SymPy 0.7.4 does not have Matrix.free_symbols so we # manually compute them instead of calling: # required_args |= matrix.free_symbols required_args |= set().union(*[i.free_symbols for i in matrix]) required_args |= find_dynamicsymbols(matrix) required_args.remove(me.dynamicsymbols._t) all_arguments = set(itertools.chain(*arguments)) for required_arg in required_args: if required_arg not in all_arguments: msg = "{} is missing from the argument sequences." raise ValueError(msg.format(required_arg)) self.matrices = matrices self.arguments = arguments if cse: self._generate_cse() else: self._ignore_cse() self._generate_code_blocks()
def _generate_cse(self, prefix='pydy_'): # This makes a big long list of every expression in all of the # matrices. exprs = [] for matrix in self.matrices: exprs += matrix[:] # Compute the common subexpresions. gen = sm.numbered_symbols(prefix) self.subexprs, simplified_exprs = sm.cse(exprs, symbols=gen) # Turn the expressions back into matrices of the same type. simplified_matrices = [] idx = 0 for matrix in self.matrices: num_rows, num_cols = matrix.shape length = num_rows * num_cols m = type(matrix)(simplified_exprs[idx:idx + length]) simplified_matrices.append(m.reshape(num_rows, num_cols)) idx += length self.simplified_matrices = tuple(simplified_matrices) def _ignore_cse(self): self.subexprs = [] self.simplified_matrices = self.matrices def _generate_pydy_printer(self): """Returns a subclass of a sympy.printing.Printer to print appropriate array index calls for all of the symbols in the equations of motion.""" array_index_map = {} for i, arg_set in enumerate(self.arguments): for j, var in enumerate(arg_set): array_index_map[var] = r'input_{}{}{}{}'.format( i + self._idx_start, self._idx_delim[0], j + self._idx_start, self._idx_delim[1]) class PyDyCodePrinter(self._base_printer): def _print_Function(self, e): if e in array_index_map.keys(): return array_index_map[e] else: return super(PyDyCodePrinter, self)._print_Function(e) def _print_Symbol(self, e): if e in array_index_map.keys(): return array_index_map[e] else: return super(PyDyCodePrinter, self)._print_Symbol(e) return PyDyCodePrinter
[docs] def comma_lists(self): """Returns a string output for each of the sequences of SymPy arguments.""" comma_lists = [] for i, arg_set in enumerate(self.arguments): comma_lists.append(', '.join([str(s) for s in arg_set])) return tuple(comma_lists)
def _generate_code_blocks(self): """Writes the blocks of code for the source file.""" printer = self._generate_pydy_printer()() self.code_blocks = {} # Creates a string of comma separated inputs. lines = [] for i, input_arg in enumerate(self.arguments): lines.append('{}input_{}'.format(self._type_declar, i + self._idx_start)) self.code_blocks['input_args'] = ', '.join(lines) # Creates a string of comma separated outputs. lines = [] for i, output_arg in enumerate(self.matrices): lines.append('{}output_{}'.format(self._type_declar, i + self._idx_start)) self.code_blocks['output_args'] = ', '.join(lines) lines = [] for i, (input_arg, explan) in enumerate(zip(self.arguments, self.comma_lists())): lines.append('{} input_{} : [{}]'.format(self._comment_char, i + self._idx_start, explan)) self.code_blocks['docstring'] = wrap_and_indent(lines, 0, comment=self._comment_char) lines = [] for var, expr in self.subexprs: lines.append(printer.doprint(expr, var)) self.code_blocks['subexprs'] = wrap_and_indent(lines, continuation=self._line_contin) outputs = '' for i, output in enumerate(self.simplified_matrices): nr, nc = output.shape lhs = sm.MatrixSymbol('output_{}'.format(i + self._idx_start), nr, nc) try: code_str = printer.doprint(output, lhs) except AttributeError: # The above fails in SymPy 0.7.4.1 because Matrix printing # isn't supported. code_lines = [] for j, element in enumerate(output): assignment = 'output_{}[{}]'.format(i, j) code_lines.append(printer.doprint(element, assignment)) code_str = '\n'.join(code_lines) outputs += wrap_and_indent(code_str.split('\n'), continuation=self._line_contin) if i != len(self.simplified_matrices) - 1: outputs += '\n\n' # space between each output self.code_blocks['outputs'] = outputs