347 lines
13 KiB
Rust
347 lines
13 KiB
Rust
//! Structural verification for Slonik IR.
|
|
|
|
use core::fmt;
|
|
|
|
use crate::ir::{Block, BlockCall, Body, Stmt, StmtData, Type, Value, ValueDef};
|
|
|
|
/// A verification error.
///
/// Carries a human-readable description of the first structural problem that
/// [`verify`] found; the text is exposed only through the `Display` impl.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifyError {
    /// Description of the problem, written verbatim by `Display`.
    msg: String,
}
|
|
|
|
impl VerifyError {
|
|
fn new(msg: impl Into<String>) -> Self {
|
|
Self { msg: msg.into() }
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for VerifyError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.write_str(&self.msg)
|
|
}
|
|
}
|
|
|
|
// Marker impl: the description comes from `Display`, and there is no
// underlying cause, so the default `source()` (returning `None`) is kept.
impl std::error::Error for VerifyError {}
|
|
|
|
/// Verifies the structural correctness of `body`.
|
|
pub fn verify(body: &Body) -> Result<(), VerifyError> {
|
|
Verifier { body }.run()
|
|
}
|
|
|
|
/// State for a single verification pass; consumed by `run`.
struct Verifier<'a> {
    /// The body being checked. Held by shared reference, so it is only read.
    body: &'a Body,
}
|
|
|
|
impl Verifier<'_> {
|
|
fn run(self) -> Result<(), VerifyError> {
|
|
self.verify_values()?;
|
|
self.verify_blocks()?;
|
|
self.verify_stmts()?;
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_values(&self) -> Result<(), VerifyError> {
|
|
for (value, data) in self.body.values.iter() {
|
|
match data.def {
|
|
ValueDef::Inst(stmt) => {
|
|
self.ensure_valid_stmt(stmt, format!("value {value}"))?;
|
|
|
|
let result = self.body.stmts[stmt].result().ok_or_else(|| {
|
|
VerifyError::new(format!(
|
|
"value {value} claims to be defined by statement {stmt}, \
|
|
but that statement produces no result"
|
|
))
|
|
})?;
|
|
|
|
if result != value {
|
|
return Err(VerifyError::new(format!(
|
|
"value {value} claims to be defined by statement {stmt}, \
|
|
but that statement's result is {result}"
|
|
)));
|
|
}
|
|
}
|
|
|
|
ValueDef::Param(block, index) => {
|
|
self.ensure_valid_block(block, format!("value {value}"))?;
|
|
|
|
let params = self.body.block_params(block);
|
|
let Some(actual) = params.get(index as usize).copied() else {
|
|
return Err(VerifyError::new(format!(
|
|
"value {value} claims to be block parameter #{index} of {block}, \
|
|
but that block has only {} parameter(s)",
|
|
params.len()
|
|
)));
|
|
};
|
|
|
|
if actual != value {
|
|
return Err(VerifyError::new(format!(
|
|
"value {value} claims to be block parameter #{index} of {block}, \
|
|
but that slot contains {actual}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
if data.ty.is_void() {
|
|
return Err(VerifyError::new(format!(
|
|
"value {value} has type void, which is not a valid SSA value type"
|
|
)));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_blocks(&self) -> Result<(), VerifyError> {
|
|
for block in self.body.blocks() {
|
|
let params = self.body.block_params(block);
|
|
for (index, ¶m) in params.iter().enumerate() {
|
|
self.ensure_valid_value(param, format!("block {block} parameter list"))?;
|
|
|
|
let data = self.body.value_data(param);
|
|
let expected = ValueDef::Param(block, index as u16);
|
|
if data.def != expected {
|
|
return Err(VerifyError::new(format!(
|
|
"block {block} parameter #{index} is {param}, \
|
|
but that value is recorded as {:?}",
|
|
data.def
|
|
)));
|
|
}
|
|
}
|
|
|
|
let stmts = self.body.block_stmts(block);
|
|
for (index, &stmt) in stmts.iter().enumerate() {
|
|
self.ensure_valid_stmt(stmt, format!("block {block} statement list"))?;
|
|
|
|
let is_last = index + 1 == stmts.len();
|
|
let is_term = self.body.stmt_data(stmt).is_terminator();
|
|
|
|
if is_term && !is_last {
|
|
return Err(VerifyError::new(format!(
|
|
"terminator statement {stmt} in {block} is not the final statement"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_stmts(&self) -> Result<(), VerifyError> {
|
|
for (stmt, &data) in self.body.stmts.iter() {
|
|
match data {
|
|
StmtData::IConst { result, .. }
|
|
| StmtData::F32Const { result, .. }
|
|
| StmtData::F64Const { result, .. }
|
|
| StmtData::BConst { result, .. } => {
|
|
self.verify_result(stmt, result)?;
|
|
}
|
|
|
|
StmtData::Unary { arg, result, .. } => {
|
|
self.ensure_valid_value(arg, format!("statement {stmt} unary operand"))?;
|
|
self.verify_result(stmt, result)?;
|
|
}
|
|
|
|
StmtData::Binary {
|
|
lhs, rhs, result, ..
|
|
} => {
|
|
self.ensure_valid_value(lhs, format!("statement {stmt} binary lhs"))?;
|
|
self.ensure_valid_value(rhs, format!("statement {stmt} binary rhs"))?;
|
|
self.verify_result(stmt, result)?;
|
|
}
|
|
|
|
StmtData::Cast { arg, result, .. } => {
|
|
self.ensure_valid_value(arg, format!("statement {stmt} cast operand"))?;
|
|
self.verify_result(stmt, result)?;
|
|
}
|
|
|
|
StmtData::Icmp {
|
|
lhs, rhs, result, ..
|
|
}
|
|
| StmtData::Fcmp {
|
|
lhs, rhs, result, ..
|
|
} => {
|
|
self.ensure_valid_value(lhs, format!("statement {stmt} compare lhs"))?;
|
|
self.ensure_valid_value(rhs, format!("statement {stmt} compare rhs"))?;
|
|
self.verify_result(stmt, result)?;
|
|
|
|
if self.body.value_type(result) != Type::bool() {
|
|
return Err(VerifyError::new(format!(
|
|
"comparison statement {stmt} must produce bool, got {}",
|
|
self.body.value_type(result)
|
|
)));
|
|
}
|
|
}
|
|
|
|
StmtData::Select {
|
|
cond,
|
|
if_true,
|
|
if_false,
|
|
result,
|
|
} => {
|
|
self.ensure_valid_value(cond, format!("statement {stmt} select condition"))?;
|
|
self.ensure_valid_value(if_true, format!("statement {stmt} select true arm"))?;
|
|
self.ensure_valid_value(
|
|
if_false,
|
|
format!("statement {stmt} select false arm"),
|
|
)?;
|
|
self.verify_result(stmt, result)?;
|
|
|
|
if self.body.value_type(cond) != Type::bool() {
|
|
return Err(VerifyError::new(format!(
|
|
"select statement {stmt} condition must have type bool, got {}",
|
|
self.body.value_type(cond)
|
|
)));
|
|
}
|
|
|
|
let t_ty = self.body.value_type(if_true);
|
|
let f_ty = self.body.value_type(if_false);
|
|
let out_ty = self.body.value_type(result);
|
|
|
|
if t_ty != f_ty {
|
|
return Err(VerifyError::new(format!(
|
|
"select statement {stmt} arms have mismatched types: {t_ty} vs {f_ty}"
|
|
)));
|
|
}
|
|
|
|
if out_ty != t_ty {
|
|
return Err(VerifyError::new(format!(
|
|
"select statement {stmt} result type {out_ty} does not match arm type {t_ty}"
|
|
)));
|
|
}
|
|
}
|
|
|
|
StmtData::Load { addr, result, .. } => {
|
|
self.ensure_valid_value(addr, format!("statement {stmt} load address"))?;
|
|
self.verify_result(stmt, result)?;
|
|
}
|
|
|
|
StmtData::Store { addr, value, .. } => {
|
|
self.ensure_valid_value(addr, format!("statement {stmt} store address"))?;
|
|
self.ensure_valid_value(value, format!("statement {stmt} store value"))?;
|
|
}
|
|
|
|
StmtData::Call { callee, args } => {
|
|
self.ensure_valid_value(callee, format!("statement {stmt} callee"))?;
|
|
self.verify_value_list(
|
|
args.as_slice(&self.body.value_lists),
|
|
format!("statement {stmt} call arguments"),
|
|
)?;
|
|
}
|
|
|
|
StmtData::Jump { dst } => {
|
|
self.verify_block_call(dst, format!("statement {stmt} jump target"))?;
|
|
}
|
|
|
|
StmtData::BrIf {
|
|
cond,
|
|
then_dst,
|
|
else_dst,
|
|
} => {
|
|
self.ensure_valid_value(cond, format!("statement {stmt} branch condition"))?;
|
|
|
|
if self.body.value_type(cond) != Type::bool() {
|
|
return Err(VerifyError::new(format!(
|
|
"statement {stmt} branch condition must have type bool, got {}",
|
|
self.body.value_type(cond)
|
|
)));
|
|
}
|
|
|
|
self.verify_block_call(then_dst, format!("statement {stmt} then-target"))?;
|
|
self.verify_block_call(else_dst, format!("statement {stmt} else-target"))?;
|
|
}
|
|
|
|
StmtData::Return { values } => {
|
|
self.verify_value_list(
|
|
values.as_slice(&self.body.value_lists),
|
|
format!("statement {stmt} return values"),
|
|
)?;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_result(&self, stmt: Stmt, result: Value) -> Result<(), VerifyError> {
|
|
self.ensure_valid_value(result, format!("statement {stmt} result"))?;
|
|
|
|
let def = self.body.value_def(result);
|
|
if def != ValueDef::Inst(stmt) {
|
|
return Err(VerifyError::new(format!(
|
|
"statement {stmt} result {result} is recorded as {:?} instead of Inst({stmt})",
|
|
def
|
|
)));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_block_call(&self, call: BlockCall, context: String) -> Result<(), VerifyError> {
|
|
self.ensure_valid_block(call.block, context.clone())?;
|
|
|
|
let args = call.args.as_slice(&self.body.value_lists);
|
|
let params = self.body.block_params(call.block);
|
|
|
|
if args.len() != params.len() {
|
|
return Err(VerifyError::new(format!(
|
|
"{context} passes {} argument(s) to {}, but that block expects {} parameter(s)",
|
|
args.len(),
|
|
call.block,
|
|
params.len()
|
|
)));
|
|
}
|
|
|
|
for (index, (&arg, ¶m)) in args.iter().zip(params.iter()).enumerate() {
|
|
self.ensure_valid_value(arg, format!("{context} argument #{index}"))?;
|
|
|
|
let arg_ty = self.body.value_type(arg);
|
|
let param_ty = self.body.value_type(param);
|
|
|
|
if arg_ty != param_ty {
|
|
return Err(VerifyError::new(format!(
|
|
"{context} argument #{index} to {} has type {}, \
|
|
but destination parameter has type {}",
|
|
call.block, arg_ty, param_ty
|
|
)));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_value_list(&self, values: &[Value], context: String) -> Result<(), VerifyError> {
|
|
for (index, &value) in values.iter().enumerate() {
|
|
self.ensure_valid_value(value, format!("{context} #{index}"))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn ensure_valid_block(&self, block: Block, context: String) -> Result<(), VerifyError> {
|
|
if !self.body.blocks.is_valid(block) {
|
|
return Err(VerifyError::new(format!(
|
|
"{context} references invalid block {block}"
|
|
)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn ensure_valid_stmt(&self, stmt: Stmt, context: String) -> Result<(), VerifyError> {
|
|
if !self.body.stmts.is_valid(stmt) {
|
|
return Err(VerifyError::new(format!(
|
|
"{context} references invalid statement {stmt}"
|
|
)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn ensure_valid_value(&self, value: Value, context: String) -> Result<(), VerifyError> {
|
|
if !self.body.values.is_valid(value) {
|
|
return Err(VerifyError::new(format!(
|
|
"{context} references invalid value {value}"
|
|
)));
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|