//! Structural verification for Slonik IR. use core::fmt; use crate::ir::{Block, BlockCall, Body, Stmt, StmtData, Type, Value, ValueDef}; /// A verification error. #[derive(Clone, Debug, PartialEq, Eq)] pub struct VerifyError { msg: String, } impl VerifyError { fn new(msg: impl Into) -> Self { Self { msg: msg.into() } } } impl fmt::Display for VerifyError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(&self.msg) } } impl std::error::Error for VerifyError {} /// Verifies the structural correctness of `body`. pub fn verify(body: &Body) -> Result<(), VerifyError> { Verifier { body }.run() } struct Verifier<'a> { body: &'a Body, } impl Verifier<'_> { fn run(self) -> Result<(), VerifyError> { self.verify_values()?; self.verify_blocks()?; self.verify_stmts()?; Ok(()) } fn verify_values(&self) -> Result<(), VerifyError> { for (value, data) in self.body.values.iter() { match data.def { ValueDef::Inst(stmt) => { self.ensure_valid_stmt(stmt, format!("value {value}"))?; let result = self.body.stmts[stmt].result().ok_or_else(|| { VerifyError::new(format!( "value {value} claims to be defined by statement {stmt}, \ but that statement produces no result" )) })?; if result != value { return Err(VerifyError::new(format!( "value {value} claims to be defined by statement {stmt}, \ but that statement's result is {result}" ))); } } ValueDef::Param(block, index) => { self.ensure_valid_block(block, format!("value {value}"))?; let params = self.body.block_params(block); let Some(actual) = params.get(index as usize).copied() else { return Err(VerifyError::new(format!( "value {value} claims to be block parameter #{index} of {block}, \ but that block has only {} parameter(s)", params.len() ))); }; if actual != value { return Err(VerifyError::new(format!( "value {value} claims to be block parameter #{index} of {block}, \ but that slot contains {actual}" ))); } } } if data.ty.is_void() { return Err(VerifyError::new(format!( "value {value} has type void, which is not a valid SSA value type" ))); } } Ok(()) } fn verify_blocks(&self) -> Result<(), VerifyError> { for block in self.body.blocks() { let params = self.body.block_params(block); for (index, ¶m) in params.iter().enumerate() { self.ensure_valid_value(param, format!("block {block} parameter list"))?; let data = self.body.value_data(param); let expected = ValueDef::Param(block, index as u16); if data.def != expected { return Err(VerifyError::new(format!( "block {block} parameter #{index} is {param}, \ but that value is recorded as {:?}", data.def ))); } } let stmts = self.body.block_stmts(block); for (index, &stmt) in stmts.iter().enumerate() { self.ensure_valid_stmt(stmt, format!("block {block} statement list"))?; let is_last = index + 1 == stmts.len(); let is_term = self.body.stmt_data(stmt).is_terminator(); if is_term && !is_last { return Err(VerifyError::new(format!( "terminator statement {stmt} in {block} is not the final statement" ))); } } } Ok(()) } fn verify_stmts(&self) -> Result<(), VerifyError> { for (stmt, &data) in self.body.stmts.iter() { match data { StmtData::IConst { result, .. } | StmtData::F32Const { result, .. } | StmtData::F64Const { result, .. } | StmtData::BConst { result, .. } => { self.verify_result(stmt, result)?; } StmtData::Unary { arg, result, .. } => { self.ensure_valid_value(arg, format!("statement {stmt} unary operand"))?; self.verify_result(stmt, result)?; } StmtData::Binary { lhs, rhs, result, .. } => { self.ensure_valid_value(lhs, format!("statement {stmt} binary lhs"))?; self.ensure_valid_value(rhs, format!("statement {stmt} binary rhs"))?; self.verify_result(stmt, result)?; } StmtData::Cast { arg, result, .. } => { self.ensure_valid_value(arg, format!("statement {stmt} cast operand"))?; self.verify_result(stmt, result)?; } StmtData::Icmp { lhs, rhs, result, .. } | StmtData::Fcmp { lhs, rhs, result, .. } => { self.ensure_valid_value(lhs, format!("statement {stmt} compare lhs"))?; self.ensure_valid_value(rhs, format!("statement {stmt} compare rhs"))?; self.verify_result(stmt, result)?; if self.body.value_type(result) != Type::bool() { return Err(VerifyError::new(format!( "comparison statement {stmt} must produce bool, got {}", self.body.value_type(result) ))); } } StmtData::Select { cond, if_true, if_false, result, } => { self.ensure_valid_value(cond, format!("statement {stmt} select condition"))?; self.ensure_valid_value(if_true, format!("statement {stmt} select true arm"))?; self.ensure_valid_value( if_false, format!("statement {stmt} select false arm"), )?; self.verify_result(stmt, result)?; if self.body.value_type(cond) != Type::bool() { return Err(VerifyError::new(format!( "select statement {stmt} condition must have type bool, got {}", self.body.value_type(cond) ))); } let t_ty = self.body.value_type(if_true); let f_ty = self.body.value_type(if_false); let out_ty = self.body.value_type(result); if t_ty != f_ty { return Err(VerifyError::new(format!( "select statement {stmt} arms have mismatched types: {t_ty} vs {f_ty}" ))); } if out_ty != t_ty { return Err(VerifyError::new(format!( "select statement {stmt} result type {out_ty} does not match arm type {t_ty}" ))); } } StmtData::Load { addr, result, .. } => { self.ensure_valid_value(addr, format!("statement {stmt} load address"))?; self.verify_result(stmt, result)?; } StmtData::Store { addr, value, .. } => { self.ensure_valid_value(addr, format!("statement {stmt} store address"))?; self.ensure_valid_value(value, format!("statement {stmt} store value"))?; } StmtData::Call { callee, args } => { self.ensure_valid_value(callee, format!("statement {stmt} callee"))?; self.verify_value_list( args.as_slice(&self.body.value_lists), format!("statement {stmt} call arguments"), )?; } StmtData::Jump { dst } => { self.verify_block_call(dst, format!("statement {stmt} jump target"))?; } StmtData::BrIf { cond, then_dst, else_dst, } => { self.ensure_valid_value(cond, format!("statement {stmt} branch condition"))?; if self.body.value_type(cond) != Type::bool() { return Err(VerifyError::new(format!( "statement {stmt} branch condition must have type bool, got {}", self.body.value_type(cond) ))); } self.verify_block_call(then_dst, format!("statement {stmt} then-target"))?; self.verify_block_call(else_dst, format!("statement {stmt} else-target"))?; } StmtData::Return { values } => { self.verify_value_list( values.as_slice(&self.body.value_lists), format!("statement {stmt} return values"), )?; } } } Ok(()) } fn verify_result(&self, stmt: Stmt, result: Value) -> Result<(), VerifyError> { self.ensure_valid_value(result, format!("statement {stmt} result"))?; let def = self.body.value_def(result); if def != ValueDef::Inst(stmt) { return Err(VerifyError::new(format!( "statement {stmt} result {result} is recorded as {:?} instead of Inst({stmt})", def ))); } Ok(()) } fn verify_block_call(&self, call: BlockCall, context: String) -> Result<(), VerifyError> { self.ensure_valid_block(call.block, context.clone())?; let args = call.args.as_slice(&self.body.value_lists); let params = self.body.block_params(call.block); if args.len() != params.len() { return Err(VerifyError::new(format!( "{context} passes {} argument(s) to {}, but that block expects {} parameter(s)", args.len(), call.block, params.len() ))); } for (index, (&arg, ¶m)) in args.iter().zip(params.iter()).enumerate() { self.ensure_valid_value(arg, format!("{context} argument #{index}"))?; let arg_ty = self.body.value_type(arg); let param_ty = self.body.value_type(param); if arg_ty != param_ty { return Err(VerifyError::new(format!( "{context} argument #{index} to {} has type {}, \ but destination parameter has type {}", call.block, arg_ty, param_ty ))); } } Ok(()) } fn verify_value_list(&self, values: &[Value], context: String) -> Result<(), VerifyError> { for (index, &value) in values.iter().enumerate() { self.ensure_valid_value(value, format!("{context} #{index}"))?; } Ok(()) } fn ensure_valid_block(&self, block: Block, context: String) -> Result<(), VerifyError> { if !self.body.blocks.is_valid(block) { return Err(VerifyError::new(format!( "{context} references invalid block {block}" ))); } Ok(()) } fn ensure_valid_stmt(&self, stmt: Stmt, context: String) -> Result<(), VerifyError> { if !self.body.stmts.is_valid(stmt) { return Err(VerifyError::new(format!( "{context} references invalid statement {stmt}" ))); } Ok(()) } fn ensure_valid_value(&self, value: Value, context: String) -> Result<(), VerifyError> { if !self.body.values.is_valid(value) { return Err(VerifyError::new(format!( "{context} references invalid value {value}" ))); } Ok(()) } }