Coverage for hdl_registers/generator/cpp/implementation.py: 97%
146 statements
Created by coverage.py v7.6.12 on 2025-03-12 11:11 +0000.
1# --------------------------------------------------------------------------------------------------
2# Copyright (c) Lukas Vik. All rights reserved.
3#
4# This file is part of the hdl-registers project, an HDL register generator fast enough to run
5# in real time.
6# https://hdl-registers.com
7# https://github.com/hdl-registers/hdl-registers
8# --------------------------------------------------------------------------------------------------
10from __future__ import annotations
12from typing import TYPE_CHECKING, Any
14from hdl_registers.field.bit import Bit
15from hdl_registers.field.bit_vector import BitVector
16from hdl_registers.field.enumeration import Enumeration
17from hdl_registers.field.integer import Integer
19from .cpp_generator_common import CppGeneratorCommon
21if TYPE_CHECKING:
22 from pathlib import Path
24 from hdl_registers.field.register_field import RegisterField
25 from hdl_registers.register import Register
26 from hdl_registers.register_array import RegisterArray
class CppImplementationGenerator(CppGeneratorCommon):
    """
    Generate a C++ class implementation.
    See the :ref:`generator_cpp` article for usage details.

    The class implementation will contain:

    * for each register, implementation of getter and setter methods for reading/writing the
      register as an ``uint32_t``.

    * for each field in each register, implementation of getter and setter methods for
      reading/writing the field as its native type (enumeration, positive/negative int, etc.).

      * The setter will read-modify-write the register to update only the specified field,
        depending on the mode of the register.

    Each field getter and setter is additionally generated in a "from value" flavor that
    operates on a register value given as argument instead of accessing the bus.
    """

    # Version of this generator. Updated since the generated code changed: signed field getters
    # no longer rely on signed integer overflow, and the brace of field "from value" setters is
    # now placed on its own line like for all other methods.
    __version__ = "2.0.3"

    SHORT_DESCRIPTION = "C++ implementation"

    # NOTE(review): presumably the default indentation used by inherited helpers such as
    # ``comment()`` — confirm in CppGeneratorCommon.
    DEFAULT_INDENTATION_LEVEL = 4

    @property
    def output_file(self) -> Path:
        """
        Result will be placed in this file.
        """
        return self.output_folder / f"{self.name}.cpp"

    def get_code(
        self,
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> str:
        """
        Get a complete C++ class implementation with all methods.

        The code consists of the assertion macros, the class constructor and, for each register,
        the getter and/or setter methods that the register mode allows.
        It is passed through ``_with_namespace`` and prefixed with an ``#include`` of the
        interface header ``include/<name>.h``.
        """
        cpp_code = f"""\
{self._macros()}\
  {self._class_name}::{self._constructor_signature()}
      : m_registers(reinterpret_cast<volatile uint32_t *>(base_address)),
        m_assertion_handler(assertion_handler)
  {{
    // Empty
  }}

"""

        for register, register_array in self.iterate_registers():
            cpp_code += self.get_separator_line(indent=2)

            description = self._get_methods_description(
                register=register, register_array=register_array
            )
            cpp_code += self.comment_block(
                text=[description, "See interface header for documentation."], indent=2
            )
            cpp_code += "\n"

            if register.mode.software_can_read:
                cpp_code += self._register_getter_function(register, register_array)

                for field in register.fields:
                    cpp_code += self._field_getter_function(register, register_array, field=field)
                    cpp_code += self._field_getter_function_from_value(
                        register, register_array, field=field
                    )

            if register.mode.software_can_write:
                cpp_code += self._register_setter_function(register, register_array)

                for field in register.fields:
                    cpp_code += self._field_setter_function(register, register_array, field=field)
                    cpp_code += self._field_setter_function_from_value(
                        register, register_array, field=field
                    )

        cpp_code_top = f'#include "include/{self.name}.h"\n\n'

        return cpp_code_top + self._with_namespace(cpp_code)

    def _macros(self) -> str:
        """
        Get the assertion macros used by the generated methods.

        One macro is created for setter value checks, one for getter value checks and one for
        array index checks.
        Each macro expands to nothing when its ``NO_REGISTER_<NAME>_ASSERT`` guard is defined.
        Otherwise, when the checked expression is false, it passes a message containing the file
        name and line number to ``m_assertion_handler``.
        """
        file_name = self.output_file.name
        # The line-continuation backslashes of the multi-line macro body are all placed directly
        # after this many characters, so that they line up in one column.
        continuation_column = 81

        def get_macro(name: str) -> str:
            macro_name = f"_{name}_ASSERT_TRUE"
            guard_name = f"NO_REGISTER_{name}_ASSERT"

            # Every line of the macro definition except the last one needs a continuation.
            continued_lines = [
                f"#define {macro_name}(expression, message) ",
                "  {",
                "    if (!static_cast<bool>(expression)) {",
                "      std::ostringstream diagnostics;",
                f'      diagnostics << "{file_name}:" << __LINE__ ',
                '                  << ": " << message << ".";',
                "      std::string diagnostic_message = diagnostics.str();",
                "      m_assertion_handler(&diagnostic_message);",
                "    }",
            ]
            # 'ljust' does not pad lines that are already longer than the column.
            macro_body = "".join(
                f"{line.ljust(continuation_column)}\\\n" for line in continued_lines
            )

            return f"""\
#ifdef {guard_name}

#define {macro_name}(expression, message) ((void)0)

#else // Not {guard_name}.

// This macro is called by the register code to check for runtime errors.
{macro_body}  }}

#endif // {guard_name}.
"""

        macros = [get_macro(name=name) for name in ("SETTER", "GETTER", "ARRAY_INDEX")]
        return "\n".join(macros) + "\n"

    def _register_index_code(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get C++ statements that declare ``index``, the position of the register in
        ``m_registers``.

        For a register in an array, the ``array_index`` argument is first checked against the
        array length.
        The index is then the array base index, plus ``array_index`` times the number of
        registers in one array element, plus the register's own index.
        Shared by the register getter and setter so that both address the same location.
        """
        if not register_array:
            return f"    const size_t index = {register.index};\n"

        array_index_check = f"""\
    _ARRAY_INDEX_ASSERT_TRUE(
      array_index < {self.name}::{register_array.name}::array_length,
      "Got '{register_array.name}' array index out of range: " << array_index
    );

"""
        index_calculation = (
            f"    const size_t index = {register_array.base_index} "
            f"+ array_index * {len(register_array.registers)} + {register.index};\n"
        )
        return array_index_check + index_calculation

    def _register_setter_function(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get the implementation of the method that writes a whole register value
        to ``m_registers``.
        """
        signature = self._register_setter_function_signature(
            register=register, register_array=register_array, indent=2
        )
        cpp_code = f"  void {self._class_name}::{signature} const\n"
        cpp_code += "  {\n"
        cpp_code += self._register_index_code(register=register, register_array=register_array)
        cpp_code += "    m_registers[index] = register_value;\n"
        cpp_code += "  }\n\n"
        return cpp_code

    def _field_setter_function(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that writes one field of a register.

        The other bits of the written value are either:

        * read from the bus via the register getter, when
          ``field_setter_should_read_modify_write`` says so for this register, or
        * the register default value otherwise.

        The field is merged into that value by the field "from value" setter, and the result is
        written with the register setter.
        """
        signature = self._field_setter_function_signature(
            register=register,
            register_array=register_array,
            field=field,
            from_value=False,
            indent=2,
        )

        cpp_code = f"  void {self._class_name}::{signature} const\n"
        cpp_code += "  {\n"

        # Forward the array index to the register getter/setter when the register is in an array.
        array_index_argument = "array_index" if register_array else ""

        if self.field_setter_should_read_modify_write(register=register):
            register_getter_function_name = self._register_getter_function_name(
                register=register, register_array=register_array
            )
            cpp_code += self.comment(
                comment="Get the current value of other fields by reading register on the bus."
            )
            current_register_value = f"{register_getter_function_name}({array_index_argument})"
        else:
            cpp_code += self.comment(
                "Set everything except for the field to default when writing the value."
            )
            current_register_value = str(register.default_value)

        cpp_code += f"    const uint32_t current_register_value = {current_register_value};\n"

        setter_from_value_function_name = self._field_setter_function_name(
            register=register, register_array=register_array, field=field, from_value=True
        )
        cpp_code += (
            "    const uint32_t result_register_value = "
            f"{setter_from_value_function_name}(current_register_value, field_value);\n"
        )

        register_setter_function_name = self._register_setter_function_name(
            register=register, register_array=register_array
        )
        register_setter_arguments = (
            "array_index, result_register_value" if register_array else "result_register_value"
        )
        cpp_code += f"    {register_setter_function_name}({register_setter_arguments});\n"

        cpp_code += "  }\n\n"

        return cpp_code

    def _field_setter_function_from_value(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that returns ``register_value`` with the given
        field replaced by ``field_value``.
        No bus access is made.

        The field value is checked by the setter assertion macro (for field types that have a
        check), masked to the field width, and shifted into place.
        """
        signature = self._field_setter_function_signature(
            register=register, register_array=register_array, field=field, from_value=True, indent=2
        )

        # Note that the opening brace is placed on its own line, like for all other methods.
        return f"""\
  uint32_t {self._class_name}::{signature} const
  {{
{self._get_field_shift_and_mask(field=field)}\
{self._get_field_value_checker(register=register, field=field, setter_or_getter="setter")}\
    const uint32_t field_value_masked = field_value & mask_at_base;
    const uint32_t field_value_masked_and_shifted = field_value_masked << shift;

    const uint32_t mask_shifted_inverse = ~mask_shifted;
    const uint32_t register_value_masked = register_value & mask_shifted_inverse;

    const uint32_t result_register_value = register_value_masked | field_value_masked_and_shifted;

    return result_register_value;
  }}

"""

    @staticmethod
    def _get_field_shift_and_mask(field: RegisterField) -> str:
        """
        Get C++ constants for the field position.

        * ``shift`` is the field base index.
        * ``mask_at_base`` has ``field.width`` ones in the least significant bits.
        * ``mask_shifted`` is that mask moved to the field position.
        """
        return f"""\
    const uint32_t shift = {field.base_index}uL;
    const uint32_t mask_at_base = 0b{"1" * field.width}uL;
    const uint32_t mask_shifted = mask_at_base << shift;

"""

    def _get_field_value_checker(
        self, register: Register, field: RegisterField, setter_or_getter: str
    ) -> str:
        """
        Get C++ assertions that ``field_value`` is legal for the field.

        * Integer fields are checked against their minimum and maximum values.
          The minimum check is left out when it is trivially true: an unsigned type with
          a minimum of zero.
        * Bit and BitVector fields are checked for bits set outside of the field width.
        * Other field types get no check.

        Arguments:
            register: Register that the field belongs to.
            field: The field to check.
            setter_or_getter: ``"setter"`` or ``"getter"``; selects which assertion macro is
                used.
        """
        comment = "// Check that field value is within the legal range."
        assertion = f"_{setter_or_getter.upper()}_ASSERT_TRUE"

        if isinstance(field, Integer):
            min_value_check_is_redundant = (
                field.min_value == 0
                and not field.is_signed
                and self._field_value_type_name(
                    register=register, register_array=None, field=field
                ).startswith("uint")
            )
            if min_value_check_is_redundant:
                min_value_check = ""
            else:
                min_value_check = f"""\
    {assertion}(
      field_value >= {field.min_value},
      "Got '{field.name}' value too small: " << field_value
    );
"""

            return f"""\
    {comment}
{min_value_check}\
    {assertion}(
      field_value <= {field.max_value},
      "Got '{field.name}' value too large: " << field_value
    );

"""

        if isinstance(field, (Bit, BitVector)):
            return f"""\
    {comment}
    const uint32_t mask_at_base_inverse = ~mask_at_base;
    {assertion}(
      (field_value & mask_at_base_inverse) == 0,
      "Got '{field.name}' value too many bits used: " << field_value
    );

"""

        return ""

    def _get_field_getter_value_checker(self, register: Register, field: RegisterField) -> str:
        """
        Get the value check for a field getter.

        Only Integer fields are checked when reading.
        """
        if isinstance(field, Integer):
            return self._get_field_value_checker(
                register=register, field=field, setter_or_getter="getter"
            )

        return ""

    def _register_getter_function(
        self, register: Register, register_array: RegisterArray | None
    ) -> str:
        """
        Get the implementation of the method that reads a whole register value
        from ``m_registers``.
        """
        signature = self._register_getter_function_signature(
            register=register, register_array=register_array, indent=2
        )
        cpp_code = f"  uint32_t {self._class_name}::{signature} const\n"
        cpp_code += "  {\n"
        cpp_code += self._register_index_code(register=register, register_array=register_array)
        cpp_code += "    const uint32_t result = m_registers[index];\n\n"
        cpp_code += "    return result;\n"
        cpp_code += "  }\n\n"
        return cpp_code

    def _field_getter_function(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that reads one field of a register.

        The register is read via the register getter, and the field is extracted by the field
        "from value" getter.
        """
        signature = self._field_getter_function_signature(
            register=register,
            register_array=register_array,
            field=field,
            from_value=False,
            indent=2,
        )

        field_type_name = self._field_value_type_name(
            register=register, register_array=register_array, field=field
        )

        cpp_code = f"  {field_type_name} {self._class_name}::{signature} const\n"
        cpp_code += "  {\n"

        register_getter_function_name = self._register_getter_function_name(
            register=register, register_array=register_array
        )

        field_getter_from_value_function_name = self._field_getter_function_name(
            register=register, register_array=register_array, field=field, from_value=True
        )

        array_index_argument = "array_index" if register_array else ""
        cpp_code += (
            "    const uint32_t register_value = "
            f"{register_getter_function_name}({array_index_argument});\n"
        )
        cpp_code += (
            f"    const {field_type_name} field_value = "
            f"{field_getter_from_value_function_name}(register_value);\n"
        )
        cpp_code += "\n    return field_value;\n"
        cpp_code += "  }\n\n"

        return cpp_code

    def _field_getter_function_from_value(
        self, register: Register, register_array: RegisterArray | None, field: RegisterField
    ) -> str:
        """
        Get the implementation of the method that extracts a field from ``register_value``
        and converts it to the field's native type.
        No bus access is made.

        Raises:
            ValueError: If the field value type is not ``uint32_t``, and the field is neither an
                Enumeration nor a signed Integer.
        """
        signature = self._field_getter_function_signature(
            register=register, register_array=register_array, field=field, from_value=True, indent=2
        )

        type_name = self._field_value_type_name(
            register=register, register_array=register_array, field=field
        )

        cpp_code = f"""\
  {type_name} {self._class_name}::{signature} const
  {{
{self._get_field_shift_and_mask(field=field)}\
    const uint32_t result_masked = register_value & mask_shifted;
    const uint32_t result_shifted = result_masked >> shift;

    {type_name} field_value;

"""

        if type_name == "uint32_t":
            cpp_code += """\
    // No casting needed.
    field_value = result_shifted;
"""

        elif isinstance(field, Enumeration):
            cpp_code += f"""\
    // "Cast" to the enum type.
    field_value = {type_name}(result_shifted);
"""

        elif isinstance(field, Integer) and field.is_signed:
            # The sign bit mask is unsigned. With a signed mask, "1 << 31" (32-bit field) and
            # "2 * sign_bit_mask" (31- and 32-bit fields) overflow a signed integer, which is
            # undefined behavior in C++. In unsigned arithmetic, the subtraction wraps modulo
            # 2^32. The conversion to the signed return type then gives the two's complement
            # value (guaranteed since C++20, and what mainstream compilers do before that).
            # For narrower fields, the resulting value is the same as with a signed mask.
            cpp_code += f"""\
    const uint32_t sign_bit_mask = 1uL << {field.width - 1};

    if (result_shifted & sign_bit_mask)
    {{
      // Value is to be interpreted as negative.
      // Sign extend it from the width of the field to the width of the return type.
      field_value = result_shifted - 2 * sign_bit_mask;
    }}
    else
    {{
      // Value is positive.
      field_value = result_shifted;
    }}
"""
        else:
            raise ValueError(f"Got unexpected field type: {type_name}")

        cpp_code += f"""
{self._get_field_getter_value_checker(register=register, field=field)}\
    return field_value;
  }}

"""

        return cpp_code