initial commit
This commit is contained in:
commit
86e4751dcc
7 changed files with 5007 additions and 0 deletions
856
src/dice.rs
Normal file
856
src/dice.rs
Normal file
|
|
@ -0,0 +1,856 @@
|
|||
//! Dice DSL parser and evaluator.
|
||||
//!
|
||||
//! Grammar (top-level):
|
||||
//!
|
||||
//! ```text
|
||||
//! input = roll (';' roll)*
|
||||
//! roll = expr (cmp NUM)? ('[' LABEL ']')?
|
||||
//! cmp = '<=' | '>=' | '<' | '>' | '==' | '!='
|
||||
//! expr = term (('+'|'-') term)*
|
||||
//! term = unary (('*'|'/') unary)*
|
||||
//! unary = '-' unary | atom
|
||||
//! atom = '(' expr ')' | NUM | [N] 'd' (NUM | '%') (KEEP)?
|
||||
//! keep = ('kh'|'kl') NUM
|
||||
//! ```
|
||||
//!
|
||||
//! `d%` is shorthand for `d100`. Labels are any text not containing `]`.
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::{self, Write};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Expr {
|
||||
Num(i64),
|
||||
Dice {
|
||||
count: u32,
|
||||
sides: u32,
|
||||
keep: Option<Keep>,
|
||||
},
|
||||
BinOp(Op, Box<Expr>, Box<Expr>),
|
||||
Neg(Box<Expr>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Keep {
|
||||
High(u32),
|
||||
Low(u32),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Op {
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum CmpOp {
|
||||
Le,
|
||||
Ge,
|
||||
Lt,
|
||||
Gt,
|
||||
Eq,
|
||||
Ne,
|
||||
}
|
||||
|
||||
impl CmpOp {
|
||||
fn apply(self, lhs: i64, rhs: i64) -> bool {
|
||||
match self {
|
||||
CmpOp::Le => lhs <= rhs,
|
||||
CmpOp::Ge => lhs >= rhs,
|
||||
CmpOp::Lt => lhs < rhs,
|
||||
CmpOp::Gt => lhs > rhs,
|
||||
CmpOp::Eq => lhs == rhs,
|
||||
CmpOp::Ne => lhs != rhs,
|
||||
}
|
||||
}
|
||||
|
||||
fn symbol(self) -> &'static str {
|
||||
match self {
|
||||
CmpOp::Le => "≤",
|
||||
CmpOp::Ge => "≥",
|
||||
CmpOp::Lt => "<",
|
||||
CmpOp::Gt => ">",
|
||||
CmpOp::Eq => "=",
|
||||
CmpOp::Ne => "≠",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RollSpec {
|
||||
pub expr: Expr,
|
||||
pub cmp: Option<(CmpOp, i64)>,
|
||||
pub label: Option<String>,
|
||||
/// Original input slice for echoing in replies (trimmed, without the
|
||||
/// optional `[label]` suffix).
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
// ---------- parser ----------
|
||||
|
||||
#[derive(Debug, thiserror::Error, PartialEq)]
|
||||
pub enum ParseError {
|
||||
#[error("expected {expected} at position {pos}, got {got}")]
|
||||
Expected {
|
||||
expected: &'static str,
|
||||
pos: usize,
|
||||
got: String,
|
||||
},
|
||||
#[error("number out of range")]
|
||||
NumberRange,
|
||||
#[error("dice count must be at least 1")]
|
||||
ZeroDice,
|
||||
#[error("dice must have at least 1 side")]
|
||||
ZeroSides,
|
||||
#[error("keep count {keep} exceeds dice count {count}")]
|
||||
KeepTooLarge { keep: u32, count: u32 },
|
||||
#[error("trailing input at position {0}")]
|
||||
Trailing(usize),
|
||||
#[error("unterminated label (missing ']')")]
|
||||
UnterminatedLabel,
|
||||
#[error("empty roll")]
|
||||
Empty,
|
||||
}
|
||||
|
||||
struct Parser<'a> {
|
||||
src: &'a [u8],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
/// Parse a single roll spec from one segment (no `;` splitting).
|
||||
pub fn parse(input: &str) -> Result<RollSpec, ParseError> {
|
||||
let trimmed = input.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(ParseError::Empty);
|
||||
}
|
||||
let mut p = Parser {
|
||||
src: trimmed.as_bytes(),
|
||||
pos: 0,
|
||||
};
|
||||
p.skip_ws();
|
||||
let expr = p.parse_expr()?;
|
||||
p.skip_ws();
|
||||
|
||||
let cmp = p.parse_cmp()?;
|
||||
p.skip_ws();
|
||||
|
||||
let label = p.parse_label()?;
|
||||
p.skip_ws();
|
||||
|
||||
if p.pos != p.src.len() {
|
||||
return Err(ParseError::Trailing(p.pos));
|
||||
}
|
||||
|
||||
// source = trimmed minus the [label] suffix, retrimmed
|
||||
let source = match &label {
|
||||
Some(_) => {
|
||||
// find the '[' by scanning backwards for the last bracket
|
||||
let bytes = trimmed.as_bytes();
|
||||
let mut bracket = None;
|
||||
for (i, b) in bytes.iter().enumerate().rev() {
|
||||
if *b == b'[' {
|
||||
bracket = Some(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
bracket
|
||||
.map(|i| trimmed[..i].trim_end().to_string())
|
||||
.unwrap_or_else(|| trimmed.to_string())
|
||||
}
|
||||
None => trimmed.to_string(),
|
||||
};
|
||||
|
||||
Ok(RollSpec {
|
||||
expr,
|
||||
cmp,
|
||||
label,
|
||||
source,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a full `input` that may contain multiple `;`-separated rolls.
|
||||
/// Returns one Result per segment, in order.
|
||||
pub fn parse_many(input: &str) -> Vec<Result<RollSpec, ParseError>> {
|
||||
input.split(';').map(parse).collect()
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn peek(&self) -> Option<u8> {
|
||||
self.src.get(self.pos).copied()
|
||||
}
|
||||
|
||||
fn peek2(&self) -> Option<u8> {
|
||||
self.src.get(self.pos + 1).copied()
|
||||
}
|
||||
|
||||
fn skip_ws(&mut self) {
|
||||
while let Some(b) = self.peek() {
|
||||
if b.is_ascii_whitespace() {
|
||||
self.pos += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eat(&mut self, c: u8) -> bool {
|
||||
if self.peek() == Some(c) {
|
||||
self.pos += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn eat_ci(&mut self, lit: &str) -> bool {
|
||||
let bytes = lit.as_bytes();
|
||||
if self.src.len() - self.pos < bytes.len() {
|
||||
return false;
|
||||
}
|
||||
if self.src[self.pos..self.pos + bytes.len()]
|
||||
.iter()
|
||||
.zip(bytes)
|
||||
.all(|(a, b)| a.eq_ignore_ascii_case(b))
|
||||
{
|
||||
self.pos += bytes.len();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_number(&mut self) -> Result<u32, ParseError> {
|
||||
let start = self.pos;
|
||||
while let Some(b) = self.peek() {
|
||||
if b.is_ascii_digit() {
|
||||
self.pos += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if start == self.pos {
|
||||
return Err(ParseError::Expected {
|
||||
expected: "number",
|
||||
pos: start,
|
||||
got: self.snippet(),
|
||||
});
|
||||
}
|
||||
std::str::from_utf8(&self.src[start..self.pos])
|
||||
.unwrap()
|
||||
.parse::<u32>()
|
||||
.map_err(|_| ParseError::NumberRange)
|
||||
}
|
||||
|
||||
fn parse_signed_number(&mut self) -> Result<i64, ParseError> {
|
||||
let negate = self.eat(b'-');
|
||||
self.skip_ws();
|
||||
let n = self.parse_number()? as i64;
|
||||
Ok(if negate { -n } else { n })
|
||||
}
|
||||
|
||||
fn snippet(&self) -> String {
|
||||
match self.peek() {
|
||||
Some(b) => format!("{:?}", b as char),
|
||||
None => "end of input".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_expr(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut lhs = self.parse_term()?;
|
||||
loop {
|
||||
self.skip_ws();
|
||||
let op = match self.peek() {
|
||||
Some(b'+') => Op::Add,
|
||||
Some(b'-') => Op::Sub,
|
||||
_ => break,
|
||||
};
|
||||
self.pos += 1;
|
||||
self.skip_ws();
|
||||
let rhs = self.parse_term()?;
|
||||
lhs = Expr::BinOp(op, Box::new(lhs), Box::new(rhs));
|
||||
}
|
||||
Ok(lhs)
|
||||
}
|
||||
|
||||
fn parse_term(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut lhs = self.parse_unary()?;
|
||||
loop {
|
||||
self.skip_ws();
|
||||
let op = match self.peek() {
|
||||
Some(b'*') => Op::Mul,
|
||||
Some(b'/') => Op::Div,
|
||||
_ => break,
|
||||
};
|
||||
self.pos += 1;
|
||||
self.skip_ws();
|
||||
let rhs = self.parse_unary()?;
|
||||
lhs = Expr::BinOp(op, Box::new(lhs), Box::new(rhs));
|
||||
}
|
||||
Ok(lhs)
|
||||
}
|
||||
|
||||
fn parse_unary(&mut self) -> Result<Expr, ParseError> {
|
||||
self.skip_ws();
|
||||
if self.eat(b'-') {
|
||||
let inner = self.parse_unary()?;
|
||||
return Ok(Expr::Neg(Box::new(inner)));
|
||||
}
|
||||
self.parse_atom()
|
||||
}
|
||||
|
||||
fn parse_atom(&mut self) -> Result<Expr, ParseError> {
|
||||
self.skip_ws();
|
||||
if self.eat(b'(') {
|
||||
let e = self.parse_expr()?;
|
||||
self.skip_ws();
|
||||
if !self.eat(b')') {
|
||||
return Err(ParseError::Expected {
|
||||
expected: "')'",
|
||||
pos: self.pos,
|
||||
got: self.snippet(),
|
||||
});
|
||||
}
|
||||
return Ok(e);
|
||||
}
|
||||
|
||||
let had_digits = matches!(self.peek(), Some(b) if b.is_ascii_digit());
|
||||
let leading = if had_digits { self.parse_number()? } else { 1 };
|
||||
|
||||
if matches!(self.peek(), Some(b'd' | b'D')) {
|
||||
self.pos += 1;
|
||||
// d% shorthand for d100
|
||||
let sides = if self.eat(b'%') {
|
||||
100
|
||||
} else {
|
||||
self.parse_number()?
|
||||
};
|
||||
if leading == 0 {
|
||||
return Err(ParseError::ZeroDice);
|
||||
}
|
||||
if sides == 0 {
|
||||
return Err(ParseError::ZeroSides);
|
||||
}
|
||||
let keep = self.parse_keep()?;
|
||||
if let Some(k) = keep {
|
||||
let keep_n = match k {
|
||||
Keep::High(n) | Keep::Low(n) => n,
|
||||
};
|
||||
if keep_n == 0 {
|
||||
return Err(ParseError::Expected {
|
||||
expected: "non-zero keep count",
|
||||
pos: self.pos,
|
||||
got: "0".into(),
|
||||
});
|
||||
}
|
||||
if keep_n > leading {
|
||||
return Err(ParseError::KeepTooLarge {
|
||||
keep: keep_n,
|
||||
count: leading,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(Expr::Dice {
|
||||
count: leading,
|
||||
sides,
|
||||
keep,
|
||||
})
|
||||
} else if had_digits {
|
||||
Ok(Expr::Num(leading as i64))
|
||||
} else {
|
||||
Err(ParseError::Expected {
|
||||
expected: "number, dice, or '('",
|
||||
pos: self.pos,
|
||||
got: self.snippet(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_keep(&mut self) -> Result<Option<Keep>, ParseError> {
|
||||
if self.eat_ci("kh") {
|
||||
let n = self.parse_number()?;
|
||||
return Ok(Some(Keep::High(n)));
|
||||
}
|
||||
if self.eat_ci("kl") {
|
||||
let n = self.parse_number()?;
|
||||
return Ok(Some(Keep::Low(n)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Comparison: parses `<=`, `>=`, `<`, `>`, `==`, `!=` followed by an integer.
|
||||
// Note `<` and `>` must be tried *after* `<=` / `>=`.
|
||||
fn parse_cmp(&mut self) -> Result<Option<(CmpOp, i64)>, ParseError> {
|
||||
let op = if self.peek() == Some(b'<') && self.peek2() == Some(b'=') {
|
||||
self.pos += 2;
|
||||
CmpOp::Le
|
||||
} else if self.peek() == Some(b'>') && self.peek2() == Some(b'=') {
|
||||
self.pos += 2;
|
||||
CmpOp::Ge
|
||||
} else if self.peek() == Some(b'=') && self.peek2() == Some(b'=') {
|
||||
self.pos += 2;
|
||||
CmpOp::Eq
|
||||
} else if self.peek() == Some(b'!') && self.peek2() == Some(b'=') {
|
||||
self.pos += 2;
|
||||
CmpOp::Ne
|
||||
} else if self.peek() == Some(b'<') {
|
||||
self.pos += 1;
|
||||
CmpOp::Lt
|
||||
} else if self.peek() == Some(b'>') {
|
||||
self.pos += 1;
|
||||
CmpOp::Gt
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
self.skip_ws();
|
||||
let n = self.parse_signed_number()?;
|
||||
Ok(Some((op, n)))
|
||||
}
|
||||
|
||||
fn parse_label(&mut self) -> Result<Option<String>, ParseError> {
|
||||
if !self.eat(b'[') {
|
||||
return Ok(None);
|
||||
}
|
||||
let start = self.pos;
|
||||
while let Some(b) = self.peek() {
|
||||
if b == b']' {
|
||||
let label = std::str::from_utf8(&self.src[start..self.pos])
|
||||
.unwrap()
|
||||
.trim()
|
||||
.to_string();
|
||||
self.pos += 1;
|
||||
return Ok(Some(label));
|
||||
}
|
||||
self.pos += 1;
|
||||
}
|
||||
Err(ParseError::UnterminatedLabel)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- evaluator ----------
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EvalError {
|
||||
#[error("division by zero")]
|
||||
DivByZero,
|
||||
#[error("too many dice (limit 1000)")]
|
||||
TooManyDice,
|
||||
#[error("arithmetic overflow")]
|
||||
Overflow,
|
||||
}
|
||||
|
||||
pub const MAX_DICE: u32 = 1000;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Roll {
|
||||
pub source: String,
|
||||
pub label: Option<String>,
|
||||
pub total: i64,
|
||||
/// Human-readable evaluation trace, e.g. `[18, ~~3~~] + 2 = **21**`.
|
||||
pub trace: String,
|
||||
/// If a comparison was supplied, the outcome and threshold.
|
||||
pub check: Option<CheckResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CheckResult {
|
||||
pub op: CmpOp,
|
||||
pub target: i64,
|
||||
pub passed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum DiceError {
|
||||
#[error(transparent)]
|
||||
Parse(#[from] ParseError),
|
||||
#[error(transparent)]
|
||||
Eval(#[from] EvalError),
|
||||
}
|
||||
|
||||
pub fn roll(input: &str, rng: &mut impl Rng) -> Result<Roll, DiceError> {
|
||||
let spec = parse(input)?;
|
||||
eval_spec(&spec, rng)
|
||||
}
|
||||
|
||||
/// Roll multiple `;`-separated expressions. Each result is independent.
|
||||
pub fn roll_many(input: &str, rng: &mut impl Rng) -> Vec<Result<Roll, DiceError>> {
|
||||
parse_many(input)
|
||||
.into_iter()
|
||||
.map(|r| match r {
|
||||
Ok(spec) => eval_spec(&spec, rng),
|
||||
Err(e) => Err(e.into()),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_spec(spec: &RollSpec, rng: &mut impl Rng) -> Result<Roll, DiceError> {
|
||||
let mut trace = String::new();
|
||||
let total = eval(&spec.expr, rng, &mut trace)?;
|
||||
let _ = write!(trace, " = **{}**", total);
|
||||
|
||||
let check = spec.cmp.map(|(op, target)| CheckResult {
|
||||
op,
|
||||
target,
|
||||
passed: op.apply(total, target),
|
||||
});
|
||||
|
||||
Ok(Roll {
|
||||
source: spec.source.clone(),
|
||||
label: spec.label.clone(),
|
||||
total,
|
||||
trace,
|
||||
check,
|
||||
})
|
||||
}
|
||||
|
||||
fn eval(expr: &Expr, rng: &mut impl Rng, out: &mut String) -> Result<i64, EvalError> {
|
||||
match expr {
|
||||
Expr::Num(n) => {
|
||||
let _ = write!(out, "{}", n);
|
||||
Ok(*n)
|
||||
}
|
||||
Expr::Neg(inner) => {
|
||||
out.push('-');
|
||||
let v = eval(inner, rng, out)?;
|
||||
v.checked_neg().ok_or(EvalError::Overflow)
|
||||
}
|
||||
Expr::Dice { count, sides, keep } => roll_dice(*count, *sides, *keep, rng, out),
|
||||
Expr::BinOp(op, a, b) => {
|
||||
let lhs = eval(a, rng, out)?;
|
||||
out.push(' ');
|
||||
out.push(match op {
|
||||
Op::Add => '+',
|
||||
Op::Sub => '-',
|
||||
Op::Mul => '*',
|
||||
Op::Div => '/',
|
||||
});
|
||||
out.push(' ');
|
||||
let rhs = eval(b, rng, out)?;
|
||||
match op {
|
||||
Op::Add => lhs.checked_add(rhs).ok_or(EvalError::Overflow),
|
||||
Op::Sub => lhs.checked_sub(rhs).ok_or(EvalError::Overflow),
|
||||
Op::Mul => lhs.checked_mul(rhs).ok_or(EvalError::Overflow),
|
||||
Op::Div => {
|
||||
if rhs == 0 {
|
||||
Err(EvalError::DivByZero)
|
||||
} else {
|
||||
lhs.checked_div(rhs).ok_or(EvalError::Overflow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn roll_dice(
|
||||
count: u32,
|
||||
sides: u32,
|
||||
keep: Option<Keep>,
|
||||
rng: &mut impl Rng,
|
||||
out: &mut String,
|
||||
) -> Result<i64, EvalError> {
|
||||
if count > MAX_DICE {
|
||||
return Err(EvalError::TooManyDice);
|
||||
}
|
||||
let rolls: Vec<u32> = (0..count).map(|_| rng.gen_range(1..=sides)).collect();
|
||||
|
||||
let kept: Vec<bool> = match keep {
|
||||
None => vec![true; rolls.len()],
|
||||
Some(Keep::High(k)) => mark_kept(&rolls, k, true),
|
||||
Some(Keep::Low(k)) => mark_kept(&rolls, k, false),
|
||||
};
|
||||
|
||||
let total: i64 = rolls
|
||||
.iter()
|
||||
.zip(&kept)
|
||||
.filter(|(_, keep)| **keep)
|
||||
.map(|(v, _)| *v as i64)
|
||||
.sum();
|
||||
|
||||
out.push('[');
|
||||
for (i, (v, keep)) in rolls.iter().zip(&kept).enumerate() {
|
||||
if i > 0 {
|
||||
out.push_str(", ");
|
||||
}
|
||||
if *keep {
|
||||
let _ = write!(out, "{}", v);
|
||||
} else {
|
||||
let _ = write!(out, "~~{}~~", v);
|
||||
}
|
||||
}
|
||||
out.push(']');
|
||||
Ok(total)
|
||||
}
|
||||
|
||||
fn mark_kept(rolls: &[u32], k: u32, high: bool) -> Vec<bool> {
|
||||
let mut indexed: Vec<(usize, u32)> = rolls.iter().copied().enumerate().collect();
|
||||
if high {
|
||||
indexed.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
} else {
|
||||
indexed.sort_by(|a, b| a.1.cmp(&b.1));
|
||||
}
|
||||
let mut kept = vec![false; rolls.len()];
|
||||
for (idx, _) in indexed.into_iter().take(k as usize) {
|
||||
kept[idx] = true;
|
||||
}
|
||||
kept
|
||||
}
|
||||
|
||||
impl fmt::Display for Roll {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// format: `1d20+2 [save]`: [18] + 2 = **20** ✓ (≤ 20)
|
||||
if let Some(label) = &self.label {
|
||||
write!(f, "`{}` [{}]: {}", self.source, label, self.trace)?;
|
||||
} else {
|
||||
write!(f, "`{}`: {}", self.source, self.trace)?;
|
||||
}
|
||||
if let Some(c) = &self.check {
|
||||
let mark = if c.passed { "✓" } else { "✗" };
|
||||
write!(f, " {} ({} {})", mark, c.op.symbol(), c.target)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- tests ----------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::rngs::mock::StepRng;
|
||||
|
||||
fn p(s: &str) -> RollSpec {
|
||||
parse(s).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_literal() {
|
||||
let r = p("42");
|
||||
assert_eq!(r.expr, Expr::Num(42));
|
||||
assert!(r.cmp.is_none());
|
||||
assert!(r.label.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bare_d() {
|
||||
assert_eq!(
|
||||
p("d20").expr,
|
||||
Expr::Dice {
|
||||
count: 1,
|
||||
sides: 20,
|
||||
keep: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ndm() {
|
||||
assert_eq!(
|
||||
p("3d6").expr,
|
||||
Expr::Dice {
|
||||
count: 3,
|
||||
sides: 6,
|
||||
keep: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_percentile_alias() {
|
||||
assert_eq!(
|
||||
p("d%").expr,
|
||||
Expr::Dice {
|
||||
count: 1,
|
||||
sides: 100,
|
||||
keep: None
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
p("1d%").expr,
|
||||
Expr::Dice {
|
||||
count: 1,
|
||||
sides: 100,
|
||||
keep: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_kh() {
|
||||
assert_eq!(
|
||||
p("2d20kh1").expr,
|
||||
Expr::Dice {
|
||||
count: 2,
|
||||
sides: 20,
|
||||
keep: Some(Keep::High(1))
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_kl_case_insensitive() {
|
||||
assert_eq!(
|
||||
p("4D6KL3").expr,
|
||||
Expr::Dice {
|
||||
count: 4,
|
||||
sides: 6,
|
||||
keep: Some(Keep::Low(3))
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_arith() {
|
||||
assert_eq!(
|
||||
p("1d20 + 3").expr,
|
||||
Expr::BinOp(
|
||||
Op::Add,
|
||||
Box::new(Expr::Dice {
|
||||
count: 1,
|
||||
sides: 20,
|
||||
keep: None
|
||||
}),
|
||||
Box::new(Expr::Num(3))
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_comparison() {
|
||||
let r = p("1d20+2 <= 14");
|
||||
assert_eq!(r.cmp, Some((CmpOp::Le, 14)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_comparison_all_ops() {
|
||||
assert_eq!(p("1d6 < 3").cmp, Some((CmpOp::Lt, 3)));
|
||||
assert_eq!(p("1d6 > 3").cmp, Some((CmpOp::Gt, 3)));
|
||||
assert_eq!(p("1d6 <= 3").cmp, Some((CmpOp::Le, 3)));
|
||||
assert_eq!(p("1d6 >= 3").cmp, Some((CmpOp::Ge, 3)));
|
||||
assert_eq!(p("1d6 == 3").cmp, Some((CmpOp::Eq, 3)));
|
||||
assert_eq!(p("1d6 != 3").cmp, Some((CmpOp::Ne, 3)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_negative_target() {
|
||||
assert_eq!(p("1d6-10 <= -2").cmp, Some((CmpOp::Le, -2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_label() {
|
||||
let r = p("1d20+2 [save vs poison]");
|
||||
assert_eq!(r.label.as_deref(), Some("save vs poison"));
|
||||
assert_eq!(r.source, "1d20+2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_label_with_cmp() {
|
||||
let r = p("1d20 <= 13 [save]");
|
||||
assert_eq!(r.label.as_deref(), Some("save"));
|
||||
assert_eq!(r.cmp, Some((CmpOp::Le, 13)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unterminated_label() {
|
||||
assert_eq!(parse("1d20 [oops"), Err(ParseError::UnterminatedLabel));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_many_splits() {
|
||||
let rs = parse_many("1d20+3; 1d8+1 [dmg]; d%");
|
||||
assert_eq!(rs.len(), 3);
|
||||
assert!(rs.iter().all(|r| r.is_ok()));
|
||||
assert_eq!(rs[1].as_ref().unwrap().label.as_deref(), Some("dmg"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_rejects_trailing() {
|
||||
assert!(matches!(parse("1d20 foo"), Err(ParseError::Trailing(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_rejects_zero_sides() {
|
||||
assert!(matches!(parse("1d0"), Err(ParseError::ZeroSides)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_rejects_keep_too_large() {
|
||||
assert!(matches!(
|
||||
parse("2d20kh3"),
|
||||
Err(ParseError::KeepTooLarge { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_paren_precedence() {
|
||||
let mut rng = StepRng::new(0, 0);
|
||||
let r = roll("(1+2)*3", &mut rng).unwrap();
|
||||
assert_eq!(r.total, 9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_comparison_passes() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..20 {
|
||||
let r = roll("1d6 <= 6", &mut rng).unwrap();
|
||||
let c = r.check.unwrap();
|
||||
assert!(c.passed);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_comparison_fails() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..20 {
|
||||
let r = roll("1d6 > 6", &mut rng).unwrap();
|
||||
let c = r.check.unwrap();
|
||||
assert!(!c.passed);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_kh_keeps_highest() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..50 {
|
||||
let r = roll("3d6kh2", &mut rng).unwrap();
|
||||
assert!(
|
||||
(2..=12).contains(&r.total),
|
||||
"sum of top 2 d6 must be in [2,12], got {}",
|
||||
r.total
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_div_by_zero() {
|
||||
let mut rng = rand::thread_rng();
|
||||
assert!(matches!(
|
||||
roll("5/0", &mut rng),
|
||||
Err(DiceError::Eval(EvalError::DivByZero))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_too_many_dice() {
|
||||
let mut rng = rand::thread_rng();
|
||||
assert!(matches!(
|
||||
roll("1001d6", &mut rng),
|
||||
Err(DiceError::Eval(EvalError::TooManyDice))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_without_label() {
|
||||
let mut rng = StepRng::new(0, 0);
|
||||
let r = roll("5+3", &mut rng).unwrap();
|
||||
assert_eq!(r.to_string(), "`5+3`: 5 + 3 = **8**");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_with_label_and_check() {
|
||||
let mut rng = StepRng::new(0, 0);
|
||||
let r = roll("5+3 <= 10 [check]", &mut rng).unwrap();
|
||||
assert_eq!(r.to_string(), "`5+3 <= 10` [check]: 5 + 3 = **8** ✓ (≤ 10)");
|
||||
}
|
||||
}
|
||||
165
src/main.rs
Normal file
165
src/main.rs
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
mod dice;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use headjack::{Bot, BotConfig, Login};
|
||||
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
|
||||
use rand::SeedableRng;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const COMMAND_PREFIX: &str = "!roll";
|
||||
|
||||
fn env_required(name: &str) -> Result<String> {
|
||||
env::var(name).with_context(|| format!("{name} must be set"))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,headjack=info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let homeserver = env_required("DICEBOT_HOMESERVER")?;
|
||||
let username = env_required("DICEBOT_USERNAME")?;
|
||||
let password = env_required("DICEBOT_PASSWORD")?;
|
||||
let allow_list = env::var("DICEBOT_ALLOW_LIST").ok();
|
||||
let state_dir = env::var("DICEBOT_STATE_DIR").ok();
|
||||
|
||||
let config = BotConfig {
|
||||
login: Login {
|
||||
homeserver_url: homeserver,
|
||||
username,
|
||||
password: Some(password),
|
||||
},
|
||||
name: Some("dicebot".to_string()),
|
||||
allow_list,
|
||||
state_dir,
|
||||
command_prefix: None,
|
||||
room_size_limit: None,
|
||||
};
|
||||
|
||||
let mut bot = Bot::new(config).await;
|
||||
bot.login().await.context("login failed")?;
|
||||
bot.join_rooms();
|
||||
bot.sync().await.context("initial sync failed")?;
|
||||
|
||||
// Use a thread-safe RNG so the handler can be Send + Sync.
|
||||
let rng = Arc::new(Mutex::new(rand::rngs::StdRng::from_entropy()));
|
||||
|
||||
{
|
||||
let rng = rng.clone();
|
||||
bot.register_text_handler(move |_sender, body, room, _event| {
|
||||
let rng = rng.clone();
|
||||
async move {
|
||||
let Some(expr) = extract_roll_request(&body) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let reply = match roll_and_format(expr, &rng).await {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => format!("⚠️ {e}"),
|
||||
};
|
||||
|
||||
let content = RoomMessageEventContent::text_markdown(reply);
|
||||
if let Err(e) = room.send(content).await {
|
||||
tracing::warn!("failed to send reply: {e}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bot.run().await.context("bot terminated")
|
||||
}
|
||||
|
||||
/// Returns Some(expr) if the message is a `!roll` invocation, else None.
|
||||
///
|
||||
/// Accepts the prefix case-insensitively, requires a word boundary after it
|
||||
/// so that "!rolled" doesn't match, and trims surrounding whitespace.
|
||||
fn extract_roll_request(body: &str) -> Option<&str> {
|
||||
let trimmed = body.trim_start();
|
||||
if !trimmed.get(..COMMAND_PREFIX.len())?.eq_ignore_ascii_case(COMMAND_PREFIX) {
|
||||
return None;
|
||||
}
|
||||
let rest = &trimmed[COMMAND_PREFIX.len()..];
|
||||
// require either end-of-string or whitespace after the prefix
|
||||
match rest.chars().next() {
|
||||
None => Some(""),
|
||||
Some(c) if c.is_whitespace() => Some(rest.trim()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn roll_and_format(
|
||||
expr: &str,
|
||||
rng: &Arc<Mutex<rand::rngs::StdRng>>,
|
||||
) -> Result<String, String> {
|
||||
if expr.is_empty() {
|
||||
return Ok(usage());
|
||||
}
|
||||
let mut guard = rng.lock().await;
|
||||
let results = dice::roll_many(expr, &mut *guard);
|
||||
|
||||
// Single roll: one-liner. Multiple: bullet list.
|
||||
if results.len() == 1 {
|
||||
return match results.into_iter().next().unwrap() {
|
||||
Ok(r) => Ok(format!("🎲 {}", r)),
|
||||
Err(e) => Err(format!("couldn't roll `{expr}`: {e}")),
|
||||
};
|
||||
}
|
||||
|
||||
let mut out = String::from("🎲\n");
|
||||
for r in results {
|
||||
match r {
|
||||
Ok(roll) => {
|
||||
out.push_str("- ");
|
||||
out.push_str(&roll.to_string());
|
||||
out.push('\n');
|
||||
}
|
||||
Err(e) => {
|
||||
out.push_str(&format!("- ⚠️ {e}\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn usage() -> String {
|
||||
"\
|
||||
**dicebot** — usage: `!roll <expr>[; <expr>…]`
|
||||
|
||||
Examples:
|
||||
- `!roll 1d20+3` — single d20 with modifier
|
||||
- `!roll 3d6` / `!roll 4d6kh3` — stat gen (roll, or drop lowest)
|
||||
- `!roll d%` — percentile (alias for `1d100`)
|
||||
- `!roll 1d20 <= 13 [save vs poison]` — check with label
|
||||
- `!roll 1d20+3; 1d8+1 [dmg]` — multiple rolls in one message
|
||||
|
||||
Ops: `+ - * /`, `()`, `kh N` / `kl N` (keep highest/lowest), \
|
||||
`<= >= < > == !=` for target checks."
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extract_accepts_prefix() {
|
||||
assert_eq!(extract_roll_request("!roll 2d6"), Some("2d6"));
|
||||
assert_eq!(extract_roll_request(" !roll 1d20+3"), Some("1d20+3"));
|
||||
assert_eq!(extract_roll_request("!ROLL 1d20"), Some("1d20"));
|
||||
assert_eq!(extract_roll_request("!roll"), Some(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_rejects_non_prefix() {
|
||||
assert_eq!(extract_roll_request("!rolled 2d6"), None);
|
||||
assert_eq!(extract_roll_request("hey !roll 2d6"), None);
|
||||
assert_eq!(extract_roll_request(""), None);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue