Coverage for hdl_registers/generator/cpp/implementation.py: 96%
262 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-31 20:50 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-31 20:50 +0000
1# --------------------------------------------------------------------------------------------------
2# Copyright (c) Lukas Vik. All rights reserved.
3#
4# This file is part of the hdl-registers project, an HDL register generator fast enough to run
5# in real time.
6# https://hdl-registers.com
7# https://github.com/hdl-registers/hdl-registers
8# --------------------------------------------------------------------------------------------------
10from __future__ import annotations
12from typing import TYPE_CHECKING, Any, Literal
14from hdl_registers.field.bit import Bit
15from hdl_registers.field.bit_vector import BitVector
16from hdl_registers.field.enumeration import Enumeration
17from hdl_registers.field.integer import Integer
18from hdl_registers.field.numerical_interpretation import (
19 Signed,
20 SignedFixedPoint,
21 Unsigned,
22 UnsignedFixedPoint,
23)
25from .cpp_generator_common import CppGeneratorCommon
27if TYPE_CHECKING:
28 from pathlib import Path
30 from hdl_registers.field.register_field import RegisterField
31 from hdl_registers.register import Register
32 from hdl_registers.register_array import RegisterArray
class CppImplementationGenerator(CppGeneratorCommon):
    """
    Generate a C++ class implementation.
    See the :ref:`generator_cpp` article for usage details.

    The class implementation will contain:

    * for each register, implementation of getter and setter methods for reading/writing the
      register as an ``uint``.

    * for each field in each register, implementation of getter and setter methods for
      reading/writing the field as its native type (enumeration, positive/negative int, etc.).

      * The setter will read-modify-write the register to update only the specified field,
        depending on the mode of the register.
    """

    __version__ = "2.0.2"

    SHORT_DESCRIPTION = "C++ implementation"

    DEFAULT_INDENTATION_LEVEL = 2

    @property
    def output_file(self) -> Path:
        """
        Result will be placed in this file.
        """
        return self.output_folder / f"{self.name}.cpp"

    def get_code(
        self,
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> str:
        """
        Get a complete C++ class implementation with all methods.

        The result is an include of the class header, followed by the runtime-check macros,
        the constructor, and then the methods of each register.
        Everything after the include is passed through ``_with_namespace``.
        """
        cpp_code = f"""\
{self._get_macros()}\
  {self._class_name}::{self._constructor_signature()}
      : m_registers(reinterpret_cast<volatile uint32_t *>(base_address)),
        m_assertion_handler(assertion_handler)
  {{
    // Empty
  }}
"""

        separator = self.get_separator_line()
        for register, register_array in self.iterate_registers():
            cpp_code += self._get_register_heading(
                register=register, register_array=register_array, separator=separator
            )

            methods_cpp: list[str] = []

            if register.mode.software_can_read:
                methods_cpp.append(
                    self._get_register_getter(register=register, register_array=register_array)
                )

                if register.fields:
                    # The main getter will perform type conversion.
                    # Provide a getter that returns the raw value also.
                    methods_cpp.append(
                        self._get_register_raw_getter(
                            register=register, register_array=register_array
                        )
                    )

                for field in register.fields:
                    methods_cpp.append(
                        self._get_field_getter(
                            register=register, register_array=register_array, field=field
                        )
                    )
                    methods_cpp.append(
                        self._get_field_getter_from_raw(
                            register=register, register_array=register_array, field=field
                        )
                    )

            if register.mode.software_can_write:
                methods_cpp.append(
                    self._get_register_setter(register=register, register_array=register_array)
                )

                if register.fields:
                    # The main setter will perform type conversion.
                    # Provide a setter that takes a raw value also.
                    methods_cpp.append(
                        self._get_register_raw_setter(
                            register=register, register_array=register_array
                        )
                    )

                for field in register.fields:
                    methods_cpp.append(
                        self._get_field_setter(
                            register=register, register_array=register_array, field=field
                        )
                    )
                    methods_cpp.append(
                        self._get_field_to_raw(
                            register=register, register_array=register_array, field=field
                        )
                    )

            cpp_code += "\n".join(methods_cpp)
            cpp_code += separator

        cpp_code += "\n"
        cpp_code_top = f'#include "include/{self.name}.h"\n\n'

        return cpp_code_top + self._with_namespace(cpp_code)

    def _get_macros(self) -> str:
        """
        Get the C++ preprocessor macros used by the generated methods for runtime checks.

        There are three wrapper macros: setter, getter and array index.
        Each one expands to a call of the base ``_ASSERT_TRUE`` macro.
        If ``NO_REGISTER_<NAME>_ASSERT`` is defined, it instead expands to ``((void)0)``.
        The base macro formats a diagnostic message and passes it to ``m_assertion_handler``
        when the expression is false.
        """

        def get_macro(name: str) -> str:
            macro_name = f"_{name}_ASSERT_TRUE"
            return f"""\
#ifdef NO_REGISTER_{name}_ASSERT
  #define {macro_name}(expression, message) ((void)0)
#else
  #define {macro_name}(expression, message) (_ASSERT_TRUE(expression, message))
#endif
"""

        setter_assert = get_macro(name="SETTER")
        getter_assert = get_macro(name="GETTER")
        array_index_assert = get_macro(name="ARRAY_INDEX")

        file_name = self.output_file.name

        # Lines of the multi-line '_ASSERT_TRUE' macro.
        # Every line except the final closing brace needs a line-continuation backslash.
        # The backslashes are aligned at a fixed column, computed rather than hand-padded.
        # A line that is longer than the column still gets one space before its backslash.
        backslash_column = 95
        continued_lines = [
            "#define _ASSERT_TRUE(expression, message)",
            "  {",
            "    if (!static_cast<bool>(expression)) {",
            "      std::ostringstream diagnostics;",
            f'      diagnostics << "{file_name}:" << __LINE__ << ": " << message << ".";',
            "      std::string diagnostic_message = diagnostics.str();",
            "      m_assertion_handler(&diagnostic_message);",
            "    }",
        ]
        assert_true = "".join(
            f"{line.ljust(backslash_column - 1)} \\\n" for line in continued_lines
        )
        assert_true += "  }\n"

        return f"""\
// Macros called by the register code below to check for runtime errors.
{setter_assert}
{getter_assert}
{array_index_assert}
// Base macro called by the other macros.
{assert_true}
"""

    def _get_register_getter(self, register: Register, register_array: RegisterArray | None) -> str:
        """
        Get the implementation of the main (type-converting) getter of a register.

        Register with fields:
            The raw value is read via the raw getter.
            Each field is converted with its "from raw" method.
            The field values are returned as a brace-initialized aggregate.

        Register without fields:
            The raw value is read from register memory and returned as-is.
        """
        comment = self._get_getter_comment()
        return_type = self._get_register_value_type(
            register=register, register_array=register_array
        )
        signature = self._register_getter_signature(
            register=register, register_array=register_array
        )

        if register.fields:
            raw_value = self._get_read_raw_value_call(
                register=register, register_array=register_array
            )

            fields = ""
            values: list[str] = []
            for field in register.fields:
                field_type = self._get_field_value_type(
                    register=register, register_array=register_array, field=field
                )
                getter_name = self._field_getter_name(
                    register=register, register_array=register_array, field=field, from_raw=True
                )
                fields += f"    const {field_type} {field.name}_value = {getter_name}(raw_value);\n"
                values.append(f"{field.name}_value")

            value = ", ".join(values)
            result = f"""\
{raw_value}
{fields}
    return {{{value}}};\
"""
        else:
            raw_value = self._get_read_raw_value_code(
                register=register, register_array=register_array
            )
            result = f"""\
{raw_value}
    return raw_value;\
"""

        return f"""\
{comment}\
  {return_type} {self._class_name}::{signature}
  {{
{result}
  }}
"""

    def _get_read_raw_value_call(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get a C++ statement that declares ``raw_value`` by calling the raw getter of the register.
        For array registers, ``array_index`` is passed on to the raw getter.
        """
        getter_name = self._register_getter_name(
            register=register, register_array=register_array, raw=True
        )
        array_index = "array_index" if register_array else ""
        return f"""\
    const uint32_t raw_value = {getter_name}({array_index});
"""

    def _get_read_raw_value_code(
        self,
        register: Register,
        register_array: RegisterArray | None,
        include_index: bool = True,
    ) -> str:
        """
        Get C++ statements that declare ``raw_value`` by reading ``m_registers[index]``.

        Arguments:
            include_index: Also emit the declaration of ``index`` (see ``_get_index``).
                Set to ``False`` when the caller has already emitted it.
        """
        index = (
            self._get_index(register=register, register_array=register_array)
            if include_index
            else ""
        )
        return f"""\
{index}\
    const uint32_t raw_value = m_registers[index];
"""

    def _get_index(self, register: Register, register_array: RegisterArray | None) -> str:
        """
        Get C++ statements that declare ``index``, the position of the register in
        ``m_registers``.

        For a register in an array:
            The index is ``base_index + array_index * <registers per element> + register.index``.
            It is preceded by an assertion that ``array_index`` is below the array length.
        """
        if register_array:
            checker = f"""\
    _ARRAY_INDEX_ASSERT_TRUE(
      array_index < {self.name}::{register_array.name}::array_length,
      "Got '{register_array.name}' array index out of range: " << array_index
    );
"""
            index = (
                f"{register_array.base_index} "
                f"+ array_index * {len(register_array.registers)} + {register.index}"
            )
        else:
            checker = ""
            index = str(register.index)

        return f"""\
{checker}\
    const size_t index = {index};
"""

    def _get_register_raw_getter(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get the implementation of the getter that returns the register value as a raw
        ``uint32_t``, without any field conversion.
        """
        comment = self._get_getter_comment(raw=True)
        signature = self._register_getter_signature(
            register=register, register_array=register_array, raw=True
        )
        raw_value = self._get_read_raw_value_code(register=register, register_array=register_array)

        return f"""\
{comment}\
  uint32_t {self._class_name}::{signature}
  {{
{raw_value}
    return raw_value;
  }}
"""

    def _get_field_getter(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the getter that reads the register and returns one field,
        converted to its native type via the field's "from raw" method.
        """
        comment = self._get_getter_comment(field=field)
        field_type = self._get_field_value_type(
            register=register, register_array=register_array, field=field
        )
        signature = self._field_getter_signature(
            register=register,
            register_array=register_array,
            field=field,
            from_raw=False,
        )
        raw_value = self._get_read_raw_value_call(register=register, register_array=register_array)
        from_raw_name = self._field_getter_name(
            register=register, register_array=register_array, field=field, from_raw=True
        )

        return f"""\
{comment}\
  {field_type} {self._class_name}::{signature}
  {{
{raw_value}
    return {from_raw_name}(raw_value);
  }}
"""

    def _get_field_getter_from_raw(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that extracts one field from a raw
        ``register_value``. The steps are:
        1. Mask and shift the raw value.
        2. Convert the result to the field's native type.
        3. If needed, range-check the result.
        """
        namespace = self._get_namespace(
            register=register, register_array=register_array, field=field
        )
        comment = self._get_from_raw_comment(field=field)
        field_type = self._get_field_value_type(
            register=register, register_array=register_array, field=field
        )
        signature = self._field_getter_signature(
            register=register, register_array=register_array, field=field, from_raw=True
        )
        cast = self._get_from_raw_cast(field=field, field_type=field_type)
        checker = self._get_field_checker(field=field, setter_or_getter="getter")

        return f"""\
{comment}\
  {field_type} {self._class_name}::{signature}
  {{
    const uint32_t result_masked = register_value & {namespace}mask_shifted;
    const uint32_t result_shifted = result_masked >> {namespace}shift;

{cast}
{checker}\
    return field_value;
  }}
"""

    def _get_from_raw_cast(self, field: RegisterField, field_type: str) -> str:  # noqa: PLR0911
        """
        Get C++ statements that convert ``result_shifted`` into a ``field_value`` variable
        of the field's native type.

        Raises:
            TypeError: For an unsupported field type or numerical interpretation.
        """
        no_cast = """\
    // No casting needed.
    const uint32_t field_value = result_shifted;
"""

        if isinstance(field, Bit):
            return """\
    // Convert to the result type.
    const bool field_value = static_cast<bool>(result_shifted);
"""

        if isinstance(field, BitVector):
            if isinstance(field.numerical_interpretation, Unsigned):
                return no_cast

            if isinstance(field.numerical_interpretation, Signed):
                return self._get_field_to_negative(field=field)

            if isinstance(field.numerical_interpretation, UnsignedFixedPoint):
                return self._get_field_to_real(
                    field=field, field_type=field_type, variable="result_shifted"
                )

            if isinstance(field.numerical_interpretation, SignedFixedPoint):
                # Sign extend into an intermediate integer first, then scale to real.
                return (
                    self._get_field_to_negative(field=field, variable="result_negative")
                    + "\n"
                    + self._get_field_to_real(
                        field=field, field_type=field_type, variable="result_negative"
                    )
                )

            raise TypeError(
                f"Got unexpected numerical interpretation type: {field.numerical_interpretation}"
            )

        if isinstance(field, Enumeration):
            return f"""\
    // "Cast" to the enum type.
    const auto field_value = {field_type}(result_shifted);
"""

        if isinstance(field, Integer):
            if field.is_signed:
                return self._get_field_to_negative(field=field)

            return no_cast

        raise TypeError(f"Got unexpected field type: {field}")

    def _get_field_to_negative(
        self, field: BitVector | Integer, variable: str = "field_value"
    ) -> str:
        """
        Get C++ statements that treat ``result_shifted`` as a two's complement number of
        ``field.width`` bits.
        The number is sign extended into a new ``int32_t`` variable named ``variable``.
        """
        # Note that the shift result has maximum value of '1 << 31', which always
        # fits in a 32-bit unsigned integer.
        return f"""\
    const uint32_t sign_bit_mask = 1uL << {field.width - 1};
    int32_t {variable};
    if (result_shifted & sign_bit_mask)
    {{
      // Value is to be interpreted as negative.
      // This can be seen as a sign extension from the width of the field to the width of
      // the result variable.
      {variable} = result_shifted - 2 * sign_bit_mask;
    }}
    else
    {{
      // Value is positive.
      {variable} = result_shifted;
    }}
"""

    def _get_field_to_real(self, field: BitVector, field_type: str, variable: str) -> str:
        """
        Get C++ statements that convert the fixed-point integer in ``variable`` to a real
        ``field_value``.
        The integer is divided by ``2 ** fraction_bit_width``.
        """
        divisor = 2**field.numerical_interpretation.fraction_bit_width
        return f"""\
    const {field_type} result_real = static_cast<{field_type}>({variable});
    const {field_type} field_value = result_real / {divisor};
"""

    def _get_field_checker(
        self, field: RegisterField, setter_or_getter: Literal["setter", "getter"]
    ) -> str:
        """
        Get a C++ assertion that ``field_value`` lies within the limits from
        ``_get_field_checker_limits``.

        The setter or getter assertion macro is used depending on ``setter_or_getter``.
        An empty string is returned if neither limit needs to be checked.
        """
        min_check, max_check = self._get_field_checker_limits(
            field=field, is_getter_not_setter=setter_or_getter == "getter"
        )

        checks: list[str] = []
        if min_check is not None:
            checks.append(f"field_value >= {min_check}")
        if max_check is not None:
            checks.append(f"field_value <= {max_check}")

        if not checks:
            return ""

        check = " && ".join(checks)
        macro = f"_{setter_or_getter.upper()}_ASSERT_TRUE"
        return f"""\
    {macro}(
      {check},
      "Got '{field.name}' value out of range: " << field_value
    );

"""

    def _get_field_checker_limits(  # noqa: C901, PLR0911
        self, field: RegisterField, is_getter_not_setter: bool
    ) -> tuple[float | None, float | None]:
        """
        Return minimum and maximum values for checking.
        ``None`` if no check is needed.

        Raises:
            TypeError: For an unsupported field type or numerical interpretation.
        """
        width_matches_cpp_type = field.width == 32

        if isinstance(field, Bit):
            # Value is represented as boolean in C++, and in HDL it is a single bit.
            # Can not be out of range in either direction.
            return None, None

        if isinstance(field, BitVector):
            if is_getter_not_setter:
                # Result of bit slice can not be out of range by definition.
                return None, None

            if isinstance(field.numerical_interpretation, Unsigned):
                return (
                    # Min is always zero, and unsigned type is used in C++.
                    None,
                    None if width_matches_cpp_type else field.numerical_interpretation.max_value,
                )

            if isinstance(field.numerical_interpretation, Signed):
                return (
                    None if width_matches_cpp_type else field.numerical_interpretation.min_value,
                    None if width_matches_cpp_type else field.numerical_interpretation.max_value,
                )

            if isinstance(field.numerical_interpretation, (UnsignedFixedPoint, SignedFixedPoint)):
                # Value is represented as floating-point in C++, which is always a signed type.
                # And we have no guarantees about the range of the type.
                # Hence we must always check.
                return (
                    field.numerical_interpretation.min_value,
                    field.numerical_interpretation.max_value,
                )

            raise TypeError(
                f"Got unexpected numerical interpretation type: {field.numerical_interpretation}"
            )

        if isinstance(field, Enumeration):
            # Take the largest element value explicitly, instead of relying on the
            # elements being ordered by value.
            max_value = max(element.value for element in field.elements)

            _, max_is_native = self._get_checker_limits_are_native(
                field=field,
                min_value=0,
                max_value=max_value,
                is_signed=False,
            )

            if is_getter_not_setter:
                return (
                    None,
                    # Result of bit slice can not be out of range if native.
                    None if max_is_native else max_value,
                )

            return (
                # Assume that the C++ type used is unsigned.
                None,
                # In C++, while being typed, the setter argument could be out of range, regardless
                # if native limit or not.
                # We can not know the width of the C++ type used. Would depend on compiler/platform.
                max_value,
            )

        if isinstance(field, Integer):
            min_is_native, max_is_native = self._get_checker_limits_are_native(
                field=field,
                min_value=field.min_value,
                max_value=field.max_value,
                is_signed=field.is_signed,
            )

            if is_getter_not_setter:
                # Result of bit slice can not be out of range if native.
                return (
                    None if min_is_native else field.min_value,
                    None if max_is_native else field.max_value,
                )

            if field.is_signed:
                # Have to be checked in general, except for the corner case where
                # the field width matches the C++ type.
                return (
                    None if min_is_native and width_matches_cpp_type else field.min_value,
                    None if max_is_native and width_matches_cpp_type else field.max_value,
                )

            return (
                # Min is always zero, and unsigned type is used in C++.
                None,
                # Has to be checked in general, except for the corner case.
                None if max_is_native and width_matches_cpp_type else field.max_value,
            )

        raise TypeError(f"Got unexpected field type: {field}")

    def _get_checker_limits_are_native(
        self, field: RegisterField, min_value: float, max_value: float, is_signed: bool
    ) -> tuple[bool, bool]:
        """
        Check whether ``min_value``/``max_value`` equal the natural range of a
        ``field.width``-bit signed or unsigned integer.

        Returns:
            ``(min_is_native, max_is_native)``.
        """
        if is_signed:
            native_min_value = -(2 ** (field.width - 1))
            native_max_value = 2 ** (field.width - 1) - 1
        else:
            native_min_value = 0
            native_max_value = 2**field.width - 1

        return min_value == native_min_value, max_value == native_max_value

    def _get_register_setter(self, register: Register, register_array: RegisterArray | None) -> str:
        """
        Get the implementation of the main (type-converting) setter of a register.

        Register with fields:
            Each field of ``register_value`` is converted with its "to raw" method.
            The results are OR'ed together and written via the raw setter.

        Register without fields:
            ``register_value`` is written directly to register memory.
        """
        comment = self._get_setter_comment(register=register)
        signature = self._register_setter_signature(
            register=register, register_array=register_array
        )

        if register.fields:
            cast = ""
            values: list[str] = []
            for field in register.fields:
                to_raw_name = self._field_to_raw_name(
                    register=register, register_array=register_array, field=field
                )
                cast += (
                    f"    const uint32_t {field.name}_value = "
                    f"{to_raw_name}(register_value.{field.name});\n"
                )
                values.append(f"{field.name}_value")

            value = " | ".join(values)
            cast += f"""\
    const uint32_t raw_value = {value};

"""
            set_raw_value = self._get_write_raw_value_call(
                register=register, register_array=register_array
            )
        else:
            cast = ""
            set_raw_value = self._get_write_raw_value_code(
                register=register, register_array=register_array
            )

        return f"""\
{comment}\
  void {self._class_name}::{signature}
  {{
{cast}\
{set_raw_value}\
  }}
"""

    def _get_write_raw_value_call(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get a C++ statement that writes ``raw_value`` by calling the raw setter of the register.
        For array registers, ``array_index`` is passed on first.
        """
        setter_name = self._register_setter_name(
            register=register, register_array=register_array, raw=True
        )
        array_index = "array_index, " if register_array else ""
        return f"""\
    {setter_name}({array_index}raw_value);
"""

    def _get_write_raw_value_code(
        self,
        register: Register,
        register_array: RegisterArray | None,
        include_index: bool = True,
    ) -> str:
        """
        Get C++ statements that write the C++ variable ``register_value`` to
        ``m_registers[index]``.

        Arguments:
            include_index: Also emit the declaration of ``index`` (see ``_get_index``).
                Set to ``False`` when the caller has already emitted it.
        """
        index = (
            self._get_index(register=register, register_array=register_array)
            if include_index
            else ""
        )
        return f"""\
{index}\
    m_registers[index] = register_value;
"""

    def _get_register_raw_setter(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get the implementation of the setter that writes a raw ``uint32_t`` register value,
        without any field conversion.
        """
        comment = self._get_setter_comment(register=register, raw=True)
        signature = self._register_setter_signature(
            register=register, register_array=register_array, raw=True
        )
        set_raw_value = self._get_write_raw_value_code(
            register=register, register_array=register_array
        )

        return f"""\
{comment}\
  void {self._class_name}::{signature}
  {{
{set_raw_value}\
  }}
"""

    def _get_field_setter(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the setter that updates one field of a register.

        The value written is a base value OR'ed with the field's converted raw value.
        How the base value is formed depends on ``field_setter_should_read_modify_write``:

        * If it is true: the current register value, with this field's bits cleared.
        * Otherwise: the OR of the default raw values of all other fields (``0`` if there
          are none).
        """
        comment = self._get_setter_comment(register=register, field=field)
        signature = self._field_setter_signature(
            register=register,
            register_array=register_array,
            field=field,
            from_raw=False,
        )
        index = self._get_index(register=register, register_array=register_array)

        if self.field_setter_should_read_modify_write(register=register):
            namespace = self._get_namespace(
                register=register, register_array=register_array, field=field
            )
            # 'index' is declared once at the top of the method, hence not included here.
            raw_value = self._get_read_raw_value_code(
                register=register, register_array=register_array, include_index=False
            )
            base_value = f"""\
{raw_value}\
    const uint32_t mask_shifted_inverse = ~{namespace}mask_shifted;
    const uint32_t base_value = raw_value & mask_shifted_inverse;
"""
        else:
            default_values = []
            for loop_field in register.fields:
                if loop_field.name != field.name:
                    namespace = self._get_namespace(
                        register=register, register_array=register_array, field=loop_field
                    )
                    default_values.append(f"{namespace}default_value_raw")

            # The '0' is needed in case there are no fields other than the one we are writing.
            default_value = " | ".join(default_values) if default_values else "0"
            base_value = f"""\
    const uint32_t base_value = {default_value};
"""

        to_raw_name = self._field_to_raw_name(
            register=register, register_array=register_array, field=field
        )
        write_raw_value = self._get_write_raw_value_code(
            register=register, register_array=register_array, include_index=False
        )

        return f"""\
{comment}\
  void {self._class_name}::{signature}
  {{
{index}
{base_value}\

    const uint32_t field_value_raw = {to_raw_name}(field_value);
    const uint32_t register_value = base_value | field_value_raw;

{write_raw_value}\
  }}
"""

    def _get_field_to_raw(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that converts a native field value to its raw
        bits, shifted into place within the register. The steps are:
        1. If needed, range-check the value.
        2. Convert the value to an unsigned representation.
        3. Shift the result.
        """
        comment = self._get_to_raw_comment(field=field)
        signature = self._field_to_raw_signature(
            register=register, register_array=register_array, field=field
        )
        namespace = self._get_namespace(
            register=register, register_array=register_array, field=field
        )
        checker = self._get_field_checker(field=field, setter_or_getter="setter")
        cast, variable = self._get_to_raw_cast(
            register=register, register_array=register_array, field=field
        )

        return f"""\
{comment}\
  uint32_t {self._class_name}::{signature}
  {{
{checker}\
{cast}\
    const uint32_t field_value_shifted = {variable} << {namespace}shift;

    return field_value_shifted;
  }}
"""

    def _get_to_raw_cast(  # noqa: C901, PLR0911
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> tuple[str, str]:
        """
        Get C++ statements that convert ``field_value`` into an unsigned representation
        that can be shifted into the register.

        Returns:
            A tuple ``(statements, variable_name)``.
            ``variable_name`` names the C++ variable holding the converted value.

        Raises:
            TypeError: For an unsupported field type or numerical interpretation.
        """
        # Useful for values that are in an unsigned integer representation, but not
        # explicitly 'uint32_t'.
        cast_to_uint32 = """\
    const uint32_t field_value_casted = static_cast<uint32_t>(field_value);
"""

        def _get_reinterpret_as_uint32(variable: str = "field_value") -> str:
            # Reinterpret as unsigned and then mask out all the sign extended bits above
            # the field. Useful for signed integer values.
            # The C-style cast below performs a signed-to-unsigned static cast, which
            # produces no change in the bit pattern (two's complement).
            # https://stackoverflow.com/a/1751368
            namespace = self._get_namespace(
                register=register, register_array=register_array, field=field
            )
            return f"""\
    const uint32_t field_value_unsigned = (uint32_t){variable};
    const uint32_t field_value_masked = field_value_unsigned & {namespace}mask_at_base;
"""

        if isinstance(field, Bit):
            return (cast_to_uint32, "field_value_casted")

        if isinstance(field, BitVector):
            if isinstance(field.numerical_interpretation, Unsigned):
                return ("", "field_value")

            if isinstance(field.numerical_interpretation, Signed):
                return (_get_reinterpret_as_uint32(), "field_value_masked")

            value_type = self._get_field_value_type(
                register=register, register_array=register_array, field=field
            )
            multiplier = 2**field.numerical_interpretation.fraction_bit_width

            fixed_type = "int32_t" if field.numerical_interpretation.is_signed else "uint32_t"

            # Static cast implies truncation, which should guarantee that the
            # fixed-point representation fits in the field.
            to_fixed = f"""\
    const {value_type} field_value_multiplied = field_value * {multiplier};
    const {fixed_type} field_value_fixed = static_cast<{fixed_type}>(field_value_multiplied);
"""

            if isinstance(field.numerical_interpretation, UnsignedFixedPoint):
                return (to_fixed, "field_value_fixed")

            if isinstance(field.numerical_interpretation, SignedFixedPoint):
                return (
                    to_fixed + _get_reinterpret_as_uint32(variable="field_value_fixed"),
                    "field_value_masked",
                )

            raise TypeError(
                f"Got unexpected numerical interpretation: {field.numerical_interpretation}"
            )

        if isinstance(field, Enumeration):
            return (cast_to_uint32, "field_value_casted")

        if isinstance(field, Integer):
            if field.is_signed:
                return (_get_reinterpret_as_uint32(), "field_value_masked")

            return ("", "field_value")

        raise TypeError(f"Got unexpected field type: {field}")