initial commit

This commit is contained in:
Ellie 2026-04-19 22:06:01 -07:00
commit 86e4751dcc
7 changed files with 5007 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
target
result

3869
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

24
Cargo.toml Normal file
View file

@ -0,0 +1,24 @@
[package]
name = "dicebot"
version = "0.1.0"
edition = "2021"
description = "Matrix dice-rolling bot built on headjack (matrix-rust-sdk + vodozemac)"
license = "AGPL-3.0-or-later"
[dependencies]
headjack = "0.5"
# matrix-sdk must match the version headjack re-exports (0.7).
# headjack pins matrix-sdk 0.7 with default features (including native-tls);
# we match so cargo doesn't complain about conflicting tls features.
matrix-sdk = "0.7"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] }
anyhow = "1"
rand = "0.8"
thiserror = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[profile.release]
lto = "thin"
codegen-units = 1
strip = true

27
flake.lock generated Normal file
View file

@ -0,0 +1,27 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1776434932,
"narHash": "sha256-gyqXNMgk3sh+ogY5svd2eNLJ6oEwzbAeaoBrrxD0lKk=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "c7f47036d3df2add644c46d712d14262b7d86c0c",
"type": "github"
},
"original": {
"owner": "nixos",
"ref": "nixos-25.11",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

64
flake.nix Normal file
View file

@ -0,0 +1,64 @@
{
description = "Matrix dice-rolling bot (headjack + vodozemac)";
inputs = {
nixpkgs.url = "github:nixos/nixpkgs/nixos-25.11";
};
outputs =
{ nixpkgs, ... }:
let
system = "x86_64-linux";
pkgs = import nixpkgs { inherit system; };
in
{
packages.${system} = rec {
dicebot = pkgs.rustPlatform.buildRustPackage {
pname = "dicebot";
version = "0.1.0";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
nativeBuildInputs = with pkgs; [
pkg-config
rustPlatform.bindgenHook
];
buildInputs = with pkgs; [
openssl
sqlite
];
# One sqlite crate's build script still invokes pkg-config → openssl.
env.OPENSSL_NO_VENDOR = "1";
meta = with pkgs.lib; {
description = "Matrix dice-rolling bot";
license = licenses.agpl3Plus;
mainProgram = "dicebot";
platforms = platforms.linux;
};
};
default = dicebot;
};
devShells.${system}.default = pkgs.mkShell {
packages = with pkgs; [
cargo
rustc
rustfmt
clippy
pkg-config
openssl
sqlite
];
};
formatter.${system} = pkgs.nixfmt-tree;
};
}

856
src/dice.rs Normal file
View file

@ -0,0 +1,856 @@
//! Dice DSL parser and evaluator.
//!
//! Grammar (top-level):
//!
//! ```text
//! input = roll (';' roll)*
//! roll = expr (cmp NUM)? ('[' LABEL ']')?
//! cmp = '<=' | '>=' | '<' | '>' | '==' | '!='
//! expr = term (('+'|'-') term)*
//! term = unary (('*'|'/') unary)*
//! unary = '-' unary | atom
//! atom = '(' expr ')' | NUM | [N] 'd' (NUM | '%') (KEEP)?
//! keep = ('kh'|'kl') NUM
//! ```
//!
//! `d%` is shorthand for `d100`. Labels are any text not containing `]`.
use rand::Rng;
use std::fmt::{self, Write};
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
Num(i64),
Dice {
count: u32,
sides: u32,
keep: Option<Keep>,
},
BinOp(Op, Box<Expr>, Box<Expr>),
Neg(Box<Expr>),
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Keep {
High(u32),
Low(u32),
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Op {
Add,
Sub,
Mul,
Div,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CmpOp {
Le,
Ge,
Lt,
Gt,
Eq,
Ne,
}
impl CmpOp {
fn apply(self, lhs: i64, rhs: i64) -> bool {
match self {
CmpOp::Le => lhs <= rhs,
CmpOp::Ge => lhs >= rhs,
CmpOp::Lt => lhs < rhs,
CmpOp::Gt => lhs > rhs,
CmpOp::Eq => lhs == rhs,
CmpOp::Ne => lhs != rhs,
}
}
fn symbol(self) -> &'static str {
match self {
CmpOp::Le => "",
CmpOp::Ge => "",
CmpOp::Lt => "<",
CmpOp::Gt => ">",
CmpOp::Eq => "=",
CmpOp::Ne => "",
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct RollSpec {
pub expr: Expr,
pub cmp: Option<(CmpOp, i64)>,
pub label: Option<String>,
/// Original input slice for echoing in replies (trimmed, without the
/// optional `[label]` suffix).
pub source: String,
}
// ---------- parser ----------
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum ParseError {
#[error("expected {expected} at position {pos}, got {got}")]
Expected {
expected: &'static str,
pos: usize,
got: String,
},
#[error("number out of range")]
NumberRange,
#[error("dice count must be at least 1")]
ZeroDice,
#[error("dice must have at least 1 side")]
ZeroSides,
#[error("keep count {keep} exceeds dice count {count}")]
KeepTooLarge { keep: u32, count: u32 },
#[error("trailing input at position {0}")]
Trailing(usize),
#[error("unterminated label (missing ']')")]
UnterminatedLabel,
#[error("empty roll")]
Empty,
}
struct Parser<'a> {
src: &'a [u8],
pos: usize,
}
/// Parse a single roll spec from one segment (no `;` splitting).
pub fn parse(input: &str) -> Result<RollSpec, ParseError> {
let trimmed = input.trim();
if trimmed.is_empty() {
return Err(ParseError::Empty);
}
let mut p = Parser {
src: trimmed.as_bytes(),
pos: 0,
};
p.skip_ws();
let expr = p.parse_expr()?;
p.skip_ws();
let cmp = p.parse_cmp()?;
p.skip_ws();
let label = p.parse_label()?;
p.skip_ws();
if p.pos != p.src.len() {
return Err(ParseError::Trailing(p.pos));
}
// source = trimmed minus the [label] suffix, retrimmed
let source = match &label {
Some(_) => {
// find the '[' by scanning backwards for the last bracket
let bytes = trimmed.as_bytes();
let mut bracket = None;
for (i, b) in bytes.iter().enumerate().rev() {
if *b == b'[' {
bracket = Some(i);
break;
}
}
bracket
.map(|i| trimmed[..i].trim_end().to_string())
.unwrap_or_else(|| trimmed.to_string())
}
None => trimmed.to_string(),
};
Ok(RollSpec {
expr,
cmp,
label,
source,
})
}
/// Parse a full `input` that may contain multiple `;`-separated rolls.
/// Returns one Result per segment, in order.
pub fn parse_many(input: &str) -> Vec<Result<RollSpec, ParseError>> {
input.split(';').map(parse).collect()
}
impl<'a> Parser<'a> {
fn peek(&self) -> Option<u8> {
self.src.get(self.pos).copied()
}
fn peek2(&self) -> Option<u8> {
self.src.get(self.pos + 1).copied()
}
fn skip_ws(&mut self) {
while let Some(b) = self.peek() {
if b.is_ascii_whitespace() {
self.pos += 1;
} else {
break;
}
}
}
fn eat(&mut self, c: u8) -> bool {
if self.peek() == Some(c) {
self.pos += 1;
true
} else {
false
}
}
fn eat_ci(&mut self, lit: &str) -> bool {
let bytes = lit.as_bytes();
if self.src.len() - self.pos < bytes.len() {
return false;
}
if self.src[self.pos..self.pos + bytes.len()]
.iter()
.zip(bytes)
.all(|(a, b)| a.eq_ignore_ascii_case(b))
{
self.pos += bytes.len();
true
} else {
false
}
}
fn parse_number(&mut self) -> Result<u32, ParseError> {
let start = self.pos;
while let Some(b) = self.peek() {
if b.is_ascii_digit() {
self.pos += 1;
} else {
break;
}
}
if start == self.pos {
return Err(ParseError::Expected {
expected: "number",
pos: start,
got: self.snippet(),
});
}
std::str::from_utf8(&self.src[start..self.pos])
.unwrap()
.parse::<u32>()
.map_err(|_| ParseError::NumberRange)
}
fn parse_signed_number(&mut self) -> Result<i64, ParseError> {
let negate = self.eat(b'-');
self.skip_ws();
let n = self.parse_number()? as i64;
Ok(if negate { -n } else { n })
}
fn snippet(&self) -> String {
match self.peek() {
Some(b) => format!("{:?}", b as char),
None => "end of input".to_string(),
}
}
fn parse_expr(&mut self) -> Result<Expr, ParseError> {
let mut lhs = self.parse_term()?;
loop {
self.skip_ws();
let op = match self.peek() {
Some(b'+') => Op::Add,
Some(b'-') => Op::Sub,
_ => break,
};
self.pos += 1;
self.skip_ws();
let rhs = self.parse_term()?;
lhs = Expr::BinOp(op, Box::new(lhs), Box::new(rhs));
}
Ok(lhs)
}
fn parse_term(&mut self) -> Result<Expr, ParseError> {
let mut lhs = self.parse_unary()?;
loop {
self.skip_ws();
let op = match self.peek() {
Some(b'*') => Op::Mul,
Some(b'/') => Op::Div,
_ => break,
};
self.pos += 1;
self.skip_ws();
let rhs = self.parse_unary()?;
lhs = Expr::BinOp(op, Box::new(lhs), Box::new(rhs));
}
Ok(lhs)
}
fn parse_unary(&mut self) -> Result<Expr, ParseError> {
self.skip_ws();
if self.eat(b'-') {
let inner = self.parse_unary()?;
return Ok(Expr::Neg(Box::new(inner)));
}
self.parse_atom()
}
fn parse_atom(&mut self) -> Result<Expr, ParseError> {
self.skip_ws();
if self.eat(b'(') {
let e = self.parse_expr()?;
self.skip_ws();
if !self.eat(b')') {
return Err(ParseError::Expected {
expected: "')'",
pos: self.pos,
got: self.snippet(),
});
}
return Ok(e);
}
let had_digits = matches!(self.peek(), Some(b) if b.is_ascii_digit());
let leading = if had_digits { self.parse_number()? } else { 1 };
if matches!(self.peek(), Some(b'd' | b'D')) {
self.pos += 1;
// d% shorthand for d100
let sides = if self.eat(b'%') {
100
} else {
self.parse_number()?
};
if leading == 0 {
return Err(ParseError::ZeroDice);
}
if sides == 0 {
return Err(ParseError::ZeroSides);
}
let keep = self.parse_keep()?;
if let Some(k) = keep {
let keep_n = match k {
Keep::High(n) | Keep::Low(n) => n,
};
if keep_n == 0 {
return Err(ParseError::Expected {
expected: "non-zero keep count",
pos: self.pos,
got: "0".into(),
});
}
if keep_n > leading {
return Err(ParseError::KeepTooLarge {
keep: keep_n,
count: leading,
});
}
}
Ok(Expr::Dice {
count: leading,
sides,
keep,
})
} else if had_digits {
Ok(Expr::Num(leading as i64))
} else {
Err(ParseError::Expected {
expected: "number, dice, or '('",
pos: self.pos,
got: self.snippet(),
})
}
}
fn parse_keep(&mut self) -> Result<Option<Keep>, ParseError> {
if self.eat_ci("kh") {
let n = self.parse_number()?;
return Ok(Some(Keep::High(n)));
}
if self.eat_ci("kl") {
let n = self.parse_number()?;
return Ok(Some(Keep::Low(n)));
}
Ok(None)
}
// Comparison: parses `<=`, `>=`, `<`, `>`, `==`, `!=` followed by an integer.
// Note `<` and `>` must be tried *after* `<=` / `>=`.
fn parse_cmp(&mut self) -> Result<Option<(CmpOp, i64)>, ParseError> {
let op = if self.peek() == Some(b'<') && self.peek2() == Some(b'=') {
self.pos += 2;
CmpOp::Le
} else if self.peek() == Some(b'>') && self.peek2() == Some(b'=') {
self.pos += 2;
CmpOp::Ge
} else if self.peek() == Some(b'=') && self.peek2() == Some(b'=') {
self.pos += 2;
CmpOp::Eq
} else if self.peek() == Some(b'!') && self.peek2() == Some(b'=') {
self.pos += 2;
CmpOp::Ne
} else if self.peek() == Some(b'<') {
self.pos += 1;
CmpOp::Lt
} else if self.peek() == Some(b'>') {
self.pos += 1;
CmpOp::Gt
} else {
return Ok(None);
};
self.skip_ws();
let n = self.parse_signed_number()?;
Ok(Some((op, n)))
}
fn parse_label(&mut self) -> Result<Option<String>, ParseError> {
if !self.eat(b'[') {
return Ok(None);
}
let start = self.pos;
while let Some(b) = self.peek() {
if b == b']' {
let label = std::str::from_utf8(&self.src[start..self.pos])
.unwrap()
.trim()
.to_string();
self.pos += 1;
return Ok(Some(label));
}
self.pos += 1;
}
Err(ParseError::UnterminatedLabel)
}
}
// ---------- evaluator ----------
#[derive(Debug, thiserror::Error)]
pub enum EvalError {
#[error("division by zero")]
DivByZero,
#[error("too many dice (limit 1000)")]
TooManyDice,
#[error("arithmetic overflow")]
Overflow,
}
pub const MAX_DICE: u32 = 1000;
#[derive(Debug, Clone)]
pub struct Roll {
pub source: String,
pub label: Option<String>,
pub total: i64,
/// Human-readable evaluation trace, e.g. `[18, ~~3~~] + 2 = **21**`.
pub trace: String,
/// If a comparison was supplied, the outcome and threshold.
pub check: Option<CheckResult>,
}
#[derive(Debug, Clone, Copy)]
pub struct CheckResult {
pub op: CmpOp,
pub target: i64,
pub passed: bool,
}
#[derive(Debug, thiserror::Error)]
pub enum DiceError {
#[error(transparent)]
Parse(#[from] ParseError),
#[error(transparent)]
Eval(#[from] EvalError),
}
pub fn roll(input: &str, rng: &mut impl Rng) -> Result<Roll, DiceError> {
let spec = parse(input)?;
eval_spec(&spec, rng)
}
/// Roll multiple `;`-separated expressions. Each result is independent.
pub fn roll_many(input: &str, rng: &mut impl Rng) -> Vec<Result<Roll, DiceError>> {
parse_many(input)
.into_iter()
.map(|r| match r {
Ok(spec) => eval_spec(&spec, rng),
Err(e) => Err(e.into()),
})
.collect()
}
fn eval_spec(spec: &RollSpec, rng: &mut impl Rng) -> Result<Roll, DiceError> {
let mut trace = String::new();
let total = eval(&spec.expr, rng, &mut trace)?;
let _ = write!(trace, " = **{}**", total);
let check = spec.cmp.map(|(op, target)| CheckResult {
op,
target,
passed: op.apply(total, target),
});
Ok(Roll {
source: spec.source.clone(),
label: spec.label.clone(),
total,
trace,
check,
})
}
fn eval(expr: &Expr, rng: &mut impl Rng, out: &mut String) -> Result<i64, EvalError> {
match expr {
Expr::Num(n) => {
let _ = write!(out, "{}", n);
Ok(*n)
}
Expr::Neg(inner) => {
out.push('-');
let v = eval(inner, rng, out)?;
v.checked_neg().ok_or(EvalError::Overflow)
}
Expr::Dice { count, sides, keep } => roll_dice(*count, *sides, *keep, rng, out),
Expr::BinOp(op, a, b) => {
let lhs = eval(a, rng, out)?;
out.push(' ');
out.push(match op {
Op::Add => '+',
Op::Sub => '-',
Op::Mul => '*',
Op::Div => '/',
});
out.push(' ');
let rhs = eval(b, rng, out)?;
match op {
Op::Add => lhs.checked_add(rhs).ok_or(EvalError::Overflow),
Op::Sub => lhs.checked_sub(rhs).ok_or(EvalError::Overflow),
Op::Mul => lhs.checked_mul(rhs).ok_or(EvalError::Overflow),
Op::Div => {
if rhs == 0 {
Err(EvalError::DivByZero)
} else {
lhs.checked_div(rhs).ok_or(EvalError::Overflow)
}
}
}
}
}
}
fn roll_dice(
count: u32,
sides: u32,
keep: Option<Keep>,
rng: &mut impl Rng,
out: &mut String,
) -> Result<i64, EvalError> {
if count > MAX_DICE {
return Err(EvalError::TooManyDice);
}
let rolls: Vec<u32> = (0..count).map(|_| rng.gen_range(1..=sides)).collect();
let kept: Vec<bool> = match keep {
None => vec![true; rolls.len()],
Some(Keep::High(k)) => mark_kept(&rolls, k, true),
Some(Keep::Low(k)) => mark_kept(&rolls, k, false),
};
let total: i64 = rolls
.iter()
.zip(&kept)
.filter(|(_, keep)| **keep)
.map(|(v, _)| *v as i64)
.sum();
out.push('[');
for (i, (v, keep)) in rolls.iter().zip(&kept).enumerate() {
if i > 0 {
out.push_str(", ");
}
if *keep {
let _ = write!(out, "{}", v);
} else {
let _ = write!(out, "~~{}~~", v);
}
}
out.push(']');
Ok(total)
}
fn mark_kept(rolls: &[u32], k: u32, high: bool) -> Vec<bool> {
let mut indexed: Vec<(usize, u32)> = rolls.iter().copied().enumerate().collect();
if high {
indexed.sort_by(|a, b| b.1.cmp(&a.1));
} else {
indexed.sort_by(|a, b| a.1.cmp(&b.1));
}
let mut kept = vec![false; rolls.len()];
for (idx, _) in indexed.into_iter().take(k as usize) {
kept[idx] = true;
}
kept
}
impl fmt::Display for Roll {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// format: `1d20+2 [save]`: [18] + 2 = **20** ✓ (≤ 20)
if let Some(label) = &self.label {
write!(f, "`{}` [{}]: {}", self.source, label, self.trace)?;
} else {
write!(f, "`{}`: {}", self.source, self.trace)?;
}
if let Some(c) = &self.check {
let mark = if c.passed { "" } else { "" };
write!(f, " {} ({} {})", mark, c.op.symbol(), c.target)?;
}
Ok(())
}
}
// ---------- tests ----------
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::mock::StepRng;
fn p(s: &str) -> RollSpec {
parse(s).unwrap()
}
#[test]
fn parse_literal() {
let r = p("42");
assert_eq!(r.expr, Expr::Num(42));
assert!(r.cmp.is_none());
assert!(r.label.is_none());
}
#[test]
fn parse_bare_d() {
assert_eq!(
p("d20").expr,
Expr::Dice {
count: 1,
sides: 20,
keep: None
}
);
}
#[test]
fn parse_ndm() {
assert_eq!(
p("3d6").expr,
Expr::Dice {
count: 3,
sides: 6,
keep: None
}
);
}
#[test]
fn parse_percentile_alias() {
assert_eq!(
p("d%").expr,
Expr::Dice {
count: 1,
sides: 100,
keep: None
}
);
assert_eq!(
p("1d%").expr,
Expr::Dice {
count: 1,
sides: 100,
keep: None
}
);
}
#[test]
fn parse_kh() {
assert_eq!(
p("2d20kh1").expr,
Expr::Dice {
count: 2,
sides: 20,
keep: Some(Keep::High(1))
}
);
}
#[test]
fn parse_kl_case_insensitive() {
assert_eq!(
p("4D6KL3").expr,
Expr::Dice {
count: 4,
sides: 6,
keep: Some(Keep::Low(3))
}
);
}
#[test]
fn parse_arith() {
assert_eq!(
p("1d20 + 3").expr,
Expr::BinOp(
Op::Add,
Box::new(Expr::Dice {
count: 1,
sides: 20,
keep: None
}),
Box::new(Expr::Num(3))
)
);
}
#[test]
fn parse_comparison() {
let r = p("1d20+2 <= 14");
assert_eq!(r.cmp, Some((CmpOp::Le, 14)));
}
#[test]
fn parse_comparison_all_ops() {
assert_eq!(p("1d6 < 3").cmp, Some((CmpOp::Lt, 3)));
assert_eq!(p("1d6 > 3").cmp, Some((CmpOp::Gt, 3)));
assert_eq!(p("1d6 <= 3").cmp, Some((CmpOp::Le, 3)));
assert_eq!(p("1d6 >= 3").cmp, Some((CmpOp::Ge, 3)));
assert_eq!(p("1d6 == 3").cmp, Some((CmpOp::Eq, 3)));
assert_eq!(p("1d6 != 3").cmp, Some((CmpOp::Ne, 3)));
}
#[test]
fn parse_negative_target() {
assert_eq!(p("1d6-10 <= -2").cmp, Some((CmpOp::Le, -2)));
}
#[test]
fn parse_label() {
let r = p("1d20+2 [save vs poison]");
assert_eq!(r.label.as_deref(), Some("save vs poison"));
assert_eq!(r.source, "1d20+2");
}
#[test]
fn parse_label_with_cmp() {
let r = p("1d20 <= 13 [save]");
assert_eq!(r.label.as_deref(), Some("save"));
assert_eq!(r.cmp, Some((CmpOp::Le, 13)));
}
#[test]
fn parse_unterminated_label() {
assert_eq!(parse("1d20 [oops"), Err(ParseError::UnterminatedLabel));
}
#[test]
fn parse_many_splits() {
let rs = parse_many("1d20+3; 1d8+1 [dmg]; d%");
assert_eq!(rs.len(), 3);
assert!(rs.iter().all(|r| r.is_ok()));
assert_eq!(rs[1].as_ref().unwrap().label.as_deref(), Some("dmg"));
}
#[test]
fn parse_rejects_trailing() {
assert!(matches!(parse("1d20 foo"), Err(ParseError::Trailing(_))));
}
#[test]
fn parse_rejects_zero_sides() {
assert!(matches!(parse("1d0"), Err(ParseError::ZeroSides)));
}
#[test]
fn parse_rejects_keep_too_large() {
assert!(matches!(
parse("2d20kh3"),
Err(ParseError::KeepTooLarge { .. })
));
}
#[test]
fn eval_paren_precedence() {
let mut rng = StepRng::new(0, 0);
let r = roll("(1+2)*3", &mut rng).unwrap();
assert_eq!(r.total, 9);
}
#[test]
fn eval_comparison_passes() {
let mut rng = rand::thread_rng();
for _ in 0..20 {
let r = roll("1d6 <= 6", &mut rng).unwrap();
let c = r.check.unwrap();
assert!(c.passed);
}
}
#[test]
fn eval_comparison_fails() {
let mut rng = rand::thread_rng();
for _ in 0..20 {
let r = roll("1d6 > 6", &mut rng).unwrap();
let c = r.check.unwrap();
assert!(!c.passed);
}
}
#[test]
fn eval_kh_keeps_highest() {
let mut rng = rand::thread_rng();
for _ in 0..50 {
let r = roll("3d6kh2", &mut rng).unwrap();
assert!(
(2..=12).contains(&r.total),
"sum of top 2 d6 must be in [2,12], got {}",
r.total
);
}
}
#[test]
fn eval_div_by_zero() {
let mut rng = rand::thread_rng();
assert!(matches!(
roll("5/0", &mut rng),
Err(DiceError::Eval(EvalError::DivByZero))
));
}
#[test]
fn eval_too_many_dice() {
let mut rng = rand::thread_rng();
assert!(matches!(
roll("1001d6", &mut rng),
Err(DiceError::Eval(EvalError::TooManyDice))
));
}
#[test]
fn display_without_label() {
let mut rng = StepRng::new(0, 0);
let r = roll("5+3", &mut rng).unwrap();
assert_eq!(r.to_string(), "`5+3`: 5 + 3 = **8**");
}
#[test]
fn display_with_label_and_check() {
let mut rng = StepRng::new(0, 0);
let r = roll("5+3 <= 10 [check]", &mut rng).unwrap();
assert_eq!(r.to_string(), "`5+3 <= 10` [check]: 5 + 3 = **8** ✓ (≤ 10)");
}
}

165
src/main.rs Normal file
View file

@ -0,0 +1,165 @@
mod dice;
use anyhow::{Context, Result};
use headjack::{Bot, BotConfig, Login};
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
use rand::SeedableRng;
use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
const COMMAND_PREFIX: &str = "!roll";
fn env_required(name: &str) -> Result<String> {
env::var(name).with_context(|| format!("{name} must be set"))
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,headjack=info")),
)
.init();
let homeserver = env_required("DICEBOT_HOMESERVER")?;
let username = env_required("DICEBOT_USERNAME")?;
let password = env_required("DICEBOT_PASSWORD")?;
let allow_list = env::var("DICEBOT_ALLOW_LIST").ok();
let state_dir = env::var("DICEBOT_STATE_DIR").ok();
let config = BotConfig {
login: Login {
homeserver_url: homeserver,
username,
password: Some(password),
},
name: Some("dicebot".to_string()),
allow_list,
state_dir,
command_prefix: None,
room_size_limit: None,
};
let mut bot = Bot::new(config).await;
bot.login().await.context("login failed")?;
bot.join_rooms();
bot.sync().await.context("initial sync failed")?;
// Use a thread-safe RNG so the handler can be Send + Sync.
let rng = Arc::new(Mutex::new(rand::rngs::StdRng::from_entropy()));
{
let rng = rng.clone();
bot.register_text_handler(move |_sender, body, room, _event| {
let rng = rng.clone();
async move {
let Some(expr) = extract_roll_request(&body) else {
return Ok(());
};
let reply = match roll_and_format(expr, &rng).await {
Ok(msg) => msg,
Err(e) => format!("⚠️ {e}"),
};
let content = RoomMessageEventContent::text_markdown(reply);
if let Err(e) = room.send(content).await {
tracing::warn!("failed to send reply: {e}");
}
Ok(())
}
});
}
bot.run().await.context("bot terminated")
}
/// Returns Some(expr) if the message is a `!roll` invocation, else None.
///
/// Accepts the prefix case-insensitively, requires a word boundary after it
/// so that "!rolled" doesn't match, and trims surrounding whitespace.
fn extract_roll_request(body: &str) -> Option<&str> {
let trimmed = body.trim_start();
if !trimmed.get(..COMMAND_PREFIX.len())?.eq_ignore_ascii_case(COMMAND_PREFIX) {
return None;
}
let rest = &trimmed[COMMAND_PREFIX.len()..];
// require either end-of-string or whitespace after the prefix
match rest.chars().next() {
None => Some(""),
Some(c) if c.is_whitespace() => Some(rest.trim()),
_ => None,
}
}
async fn roll_and_format(
expr: &str,
rng: &Arc<Mutex<rand::rngs::StdRng>>,
) -> Result<String, String> {
if expr.is_empty() {
return Ok(usage());
}
let mut guard = rng.lock().await;
let results = dice::roll_many(expr, &mut *guard);
// Single roll: one-liner. Multiple: bullet list.
if results.len() == 1 {
return match results.into_iter().next().unwrap() {
Ok(r) => Ok(format!("🎲 {}", r)),
Err(e) => Err(format!("couldn't roll `{expr}`: {e}")),
};
}
let mut out = String::from("🎲\n");
for r in results {
match r {
Ok(roll) => {
out.push_str("- ");
out.push_str(&roll.to_string());
out.push('\n');
}
Err(e) => {
out.push_str(&format!("- ⚠️ {e}\n"));
}
}
}
Ok(out)
}
fn usage() -> String {
"\
**dicebot** usage: `!roll <expr>[; <expr>]`
Examples:
- `!roll 1d20+3` single d20 with modifier
- `!roll 3d6` / `!roll 4d6kh3` stat gen (roll, or drop lowest)
- `!roll d%` percentile (alias for `1d100`)
- `!roll 1d20 <= 13 [save vs poison]` check with label
- `!roll 1d20+3; 1d8+1 [dmg]` multiple rolls in one message
Ops: `+ - * /`, `()`, `kh N` / `kl N` (keep highest/lowest), \
`<= >= < > == !=` for target checks."
.to_string()
}
#[cfg(test)]
mod integration_tests {
use super::*;
#[test]
fn extract_accepts_prefix() {
assert_eq!(extract_roll_request("!roll 2d6"), Some("2d6"));
assert_eq!(extract_roll_request(" !roll 1d20+3"), Some("1d20+3"));
assert_eq!(extract_roll_request("!ROLL 1d20"), Some("1d20"));
assert_eq!(extract_roll_request("!roll"), Some(""));
}
#[test]
fn extract_rejects_non_prefix() {
assert_eq!(extract_roll_request("!rolled 2d6"), None);
assert_eq!(extract_roll_request("hey !roll 2d6"), None);
assert_eq!(extract_roll_request(""), None);
}
}