slonik/src/ir/verify.rs

347 lines
13 KiB
Rust

//! Structural verification for Slonik IR.
use core::fmt;
use crate::ir::{Block, BlockCall, Body, Stmt, StmtData, Type, Value, ValueDef};
/// An error reported when an IR body fails structural verification.
///
/// The error carries a single human-readable message describing the first
/// violation that was found; that message is what `Display` prints.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct VerifyError {
    /// Human-readable description of the violation.
    msg: String,
}
impl VerifyError {
    /// Builds an error from anything convertible into an owned message string.
    fn new(msg: impl Into<String>) -> Self {
        let msg: String = msg.into();
        Self { msg }
    }
}
impl fmt::Display for VerifyError {
    /// Writes the stored message verbatim.
    ///
    /// The message goes straight to the formatter via `write_str`, so width,
    /// fill and precision flags on the format spec are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { msg } = self;
        f.write_str(msg.as_str())
    }
}
// Marker impl: no methods are overridden, so the `std::error::Error` defaults
// apply. It lets `VerifyError` be used wherever a standard error is expected.
impl std::error::Error for VerifyError {}
/// Verifies the structural correctness of `body`.
///
/// Runs the value, block and statement passes in that order.
///
/// # Errors
///
/// Returns a [`VerifyError`] describing the first violation encountered.
pub fn verify(body: &Body) -> Result<(), VerifyError> {
    let verifier = Verifier { body };
    verifier.run()
}
/// Holds the borrowed [`Body`] that the individual verification passes inspect.
struct Verifier<'a> {
    /// The IR body under verification.
    body: &'a Body,
}
impl Verifier<'_> {
/// Runs every verification pass, stopping at the first error.
///
/// Order matters: values are checked first, then blocks, then statements.
fn run(self) -> Result<(), VerifyError> {
    self.verify_values()?;
    self.verify_blocks()?;
    // The last pass's result is the overall result.
    self.verify_stmts()
}
/// Checks that each value's recorded definition agrees with the statement or
/// block parameter slot it names, and that no value has type void.
///
/// For every value the definition check runs before the void check; the first
/// violation found is returned.
fn verify_values(&self) -> Result<(), VerifyError> {
    for (value, data) in self.body.values.iter() {
        match data.def {
            ValueDef::Inst(stmt) => {
                self.ensure_valid_stmt(stmt, format!("value {value}"))?;
                // The defining statement must name exactly this value as its result.
                match self.body.stmts[stmt].result() {
                    None => {
                        return Err(VerifyError::new(format!(
                            "value {value} claims to be defined by statement {stmt}, \
                             but that statement produces no result"
                        )));
                    }
                    Some(result) if result != value => {
                        return Err(VerifyError::new(format!(
                            "value {value} claims to be defined by statement {stmt}, \
                             but that statement's result is {result}"
                        )));
                    }
                    Some(_) => {}
                }
            }
            ValueDef::Param(block, index) => {
                self.ensure_valid_block(block, format!("value {value}"))?;
                let params = self.body.block_params(block);
                // The named slot must exist and must hold exactly this value.
                match params.get(index as usize).copied() {
                    None => {
                        return Err(VerifyError::new(format!(
                            "value {value} claims to be block parameter #{index} of {block}, \
                             but that block has only {} parameter(s)",
                            params.len()
                        )));
                    }
                    Some(actual) if actual != value => {
                        return Err(VerifyError::new(format!(
                            "value {value} claims to be block parameter #{index} of {block}, \
                             but that slot contains {actual}"
                        )));
                    }
                    Some(_) => {}
                }
            }
        }

        if data.ty.is_void() {
            return Err(VerifyError::new(format!(
                "value {value} has type void, which is not a valid SSA value type"
            )));
        }
    }
    Ok(())
}
/// Checks each block's parameter list and statement list.
///
/// For every block this verifies that:
/// - each parameter is a valid value whose recorded definition is
///   `ValueDef::Param` for this block and for the position it occupies;
/// - each listed statement is a valid statement; and
/// - a terminator statement appears only as the final statement.
///
/// Returns the first violation found.
fn verify_blocks(&self) -> Result<(), VerifyError> {
    for block in self.body.blocks() {
        let params = self.body.block_params(block);
        for (index, &param) in params.iter().enumerate() {
            self.ensure_valid_value(param, format!("block {block} parameter list"))?;
            let data = self.body.value_data(param);
            // Widen the recorded 16-bit slot rather than narrowing `index` with
            // `as u16`: the narrowing cast wraps for positions >= 65536, which
            // would let a parameter at position 65536 + k be accepted when it
            // is recorded as slot k.
            let recorded_here = matches!(
                data.def,
                ValueDef::Param(owner, slot) if owner == block && usize::from(slot) == index
            );
            if !recorded_here {
                return Err(VerifyError::new(format!(
                    "block {block} parameter #{index} is {param}, \
                     but that value is recorded as {:?}",
                    data.def
                )));
            }
        }

        // NOTE(review): only "terminator before the end" is rejected here; a
        // block whose last statement is not a terminator is accepted. Confirm
        // that this leniency is intended.
        let stmts = self.body.block_stmts(block);
        for (index, &stmt) in stmts.iter().enumerate() {
            self.ensure_valid_stmt(stmt, format!("block {block} statement list"))?;
            let is_last = index + 1 == stmts.len();
            let is_term = self.body.stmt_data(stmt).is_terminator();
            if is_term && !is_last {
                return Err(VerifyError::new(format!(
                    "terminator statement {stmt} in {block} is not the final statement"
                )));
            }
        }
    }
    Ok(())
}
/// Checks every statement stored in the body.
///
/// Per statement kind this validates that each operand refers to a valid
/// value, that a produced result is recorded as defined by that statement,
/// and the type rules enforced here: comparisons produce bool, `Select` and
/// `BrIf` conditions are bool, `Select` arms and result share one type, and
/// block-call arguments match the destination block's parameters.
///
/// Returns the first violation found.
fn verify_stmts(&self) -> Result<(), VerifyError> {
    for (stmt, data) in self.body.stmts.iter() {
        match *data {
            // Constants have no operands; only the result binding is checked.
            StmtData::IConst { result, .. }
            | StmtData::F32Const { result, .. }
            | StmtData::F64Const { result, .. }
            | StmtData::BConst { result, .. } => {
                self.verify_result(stmt, result)?;
            }

            StmtData::Unary { arg, result, .. } => {
                self.ensure_valid_value(arg, format!("statement {stmt} unary operand"))?;
                self.verify_result(stmt, result)?;
            }

            StmtData::Binary {
                lhs, rhs, result, ..
            } => {
                self.ensure_valid_value(lhs, format!("statement {stmt} binary lhs"))?;
                self.ensure_valid_value(rhs, format!("statement {stmt} binary rhs"))?;
                self.verify_result(stmt, result)?;
            }

            StmtData::Cast { arg, result, .. } => {
                self.ensure_valid_value(arg, format!("statement {stmt} cast operand"))?;
                self.verify_result(stmt, result)?;
            }

            StmtData::Icmp {
                lhs, rhs, result, ..
            }
            | StmtData::Fcmp {
                lhs, rhs, result, ..
            } => {
                self.ensure_valid_value(lhs, format!("statement {stmt} compare lhs"))?;
                self.ensure_valid_value(rhs, format!("statement {stmt} compare rhs"))?;
                self.verify_result(stmt, result)?;

                // A comparison always yields a boolean.
                let result_ty = self.body.value_type(result);
                if result_ty != Type::bool() {
                    return Err(VerifyError::new(format!(
                        "comparison statement {stmt} must produce bool, got {}",
                        result_ty
                    )));
                }
            }

            StmtData::Select {
                cond,
                if_true,
                if_false,
                result,
            } => {
                self.ensure_valid_value(cond, format!("statement {stmt} select condition"))?;
                self.ensure_valid_value(if_true, format!("statement {stmt} select true arm"))?;
                self.ensure_valid_value(
                    if_false,
                    format!("statement {stmt} select false arm"),
                )?;
                self.verify_result(stmt, result)?;

                let cond_ty = self.body.value_type(cond);
                if cond_ty != Type::bool() {
                    return Err(VerifyError::new(format!(
                        "select statement {stmt} condition must have type bool, got {}",
                        cond_ty
                    )));
                }

                // Both arms and the result must share a single type.
                let t_ty = self.body.value_type(if_true);
                let f_ty = self.body.value_type(if_false);
                let out_ty = self.body.value_type(result);
                if t_ty != f_ty {
                    return Err(VerifyError::new(format!(
                        "select statement {stmt} arms have mismatched types: {t_ty} vs {f_ty}"
                    )));
                }
                if out_ty != t_ty {
                    return Err(VerifyError::new(format!(
                        "select statement {stmt} result type {out_ty} does not match arm type {t_ty}"
                    )));
                }
            }

            StmtData::Load { addr, result, .. } => {
                self.ensure_valid_value(addr, format!("statement {stmt} load address"))?;
                self.verify_result(stmt, result)?;
            }

            StmtData::Store { addr, value, .. } => {
                self.ensure_valid_value(addr, format!("statement {stmt} store address"))?;
                self.ensure_valid_value(value, format!("statement {stmt} store value"))?;
            }

            StmtData::Call { callee, args } => {
                self.ensure_valid_value(callee, format!("statement {stmt} callee"))?;
                let arg_values = args.as_slice(&self.body.value_lists);
                self.verify_value_list(arg_values, format!("statement {stmt} call arguments"))?;
            }

            StmtData::Jump { dst } => {
                self.verify_block_call(dst, format!("statement {stmt} jump target"))?;
            }

            StmtData::BrIf {
                cond,
                then_dst,
                else_dst,
            } => {
                self.ensure_valid_value(cond, format!("statement {stmt} branch condition"))?;
                let cond_ty = self.body.value_type(cond);
                if cond_ty != Type::bool() {
                    return Err(VerifyError::new(format!(
                        "statement {stmt} branch condition must have type bool, got {}",
                        cond_ty
                    )));
                }
                self.verify_block_call(then_dst, format!("statement {stmt} then-target"))?;
                self.verify_block_call(else_dst, format!("statement {stmt} else-target"))?;
            }

            StmtData::Return { values } => {
                let returned = values.as_slice(&self.body.value_lists);
                self.verify_value_list(returned, format!("statement {stmt} return values"))?;
            }
        }
    }
    Ok(())
}
/// Checks that `result` is a valid value recorded as defined by `stmt`.
///
/// Fails if `result` is not a valid value, or if its recorded definition is
/// anything other than `ValueDef::Inst(stmt)`.
fn verify_result(&self, stmt: Stmt, result: Value) -> Result<(), VerifyError> {
    self.ensure_valid_value(result, format!("statement {stmt} result"))?;

    let recorded = self.body.value_def(result);
    if recorded == ValueDef::Inst(stmt) {
        Ok(())
    } else {
        Err(VerifyError::new(format!(
            "statement {stmt} result {result} is recorded as {:?} instead of Inst({stmt})",
            recorded
        )))
    }
}
/// Checks a branch target together with the arguments passed to it.
///
/// The destination block must be valid, the argument count must equal the
/// destination's parameter count, and each argument must be a valid value
/// whose type equals the type of the parameter it is passed to. `context`
/// is prefixed to every error message produced here.
fn verify_block_call(&self, call: BlockCall, context: String) -> Result<(), VerifyError> {
    let target = call.block;
    self.ensure_valid_block(target, context.clone())?;

    let args = call.args.as_slice(&self.body.value_lists);
    let params = self.body.block_params(target);

    // Arity first: the pairwise type checks below assume equal lengths.
    let (passed, expected) = (args.len(), params.len());
    if passed != expected {
        return Err(VerifyError::new(format!(
            "{context} passes {} argument(s) to {}, but that block expects {} parameter(s)",
            passed, target, expected
        )));
    }

    let pairs = args.iter().copied().zip(params.iter().copied());
    for (index, (arg, param)) in pairs.enumerate() {
        self.ensure_valid_value(arg, format!("{context} argument #{index}"))?;
        let arg_ty = self.body.value_type(arg);
        let param_ty = self.body.value_type(param);
        if arg_ty != param_ty {
            return Err(VerifyError::new(format!(
                "{context} argument #{index} to {} has type {}, \
                 but destination parameter has type {}",
                target, arg_ty, param_ty
            )));
        }
    }
    Ok(())
}
/// Checks that every entry of `values` is a valid value.
///
/// Stops at the first invalid entry; the error names it as
/// `"{context} #{index}"`.
fn verify_value_list(&self, values: &[Value], context: String) -> Result<(), VerifyError> {
    values
        .iter()
        .enumerate()
        .try_for_each(|(index, &value)| {
            self.ensure_valid_value(value, format!("{context} #{index}"))
        })
}
/// Succeeds when `block` is a valid block of the body; otherwise returns an
/// error whose message starts with `context`.
fn ensure_valid_block(&self, block: Block, context: String) -> Result<(), VerifyError> {
    if self.body.blocks.is_valid(block) {
        return Ok(());
    }
    Err(VerifyError::new(format!(
        "{context} references invalid block {block}"
    )))
}
/// Succeeds when `stmt` is a valid statement of the body; otherwise returns
/// an error whose message starts with `context`.
fn ensure_valid_stmt(&self, stmt: Stmt, context: String) -> Result<(), VerifyError> {
    if self.body.stmts.is_valid(stmt) {
        return Ok(());
    }
    Err(VerifyError::new(format!(
        "{context} references invalid statement {stmt}"
    )))
}
/// Succeeds when `value` is a valid value of the body; otherwise returns an
/// error whose message starts with `context`.
fn ensure_valid_value(&self, value: Value, context: String) -> Result<(), VerifyError> {
    if self.body.values.is_valid(value) {
        return Ok(());
    }
    Err(VerifyError::new(format!(
        "{context} references invalid value {value}"
    )))
}
}