/// model/exponential.rs, BAYES STAR, (c) coppola.ai 2024 use super::choose::extract_backimplications_from_proposition; use super::config::ConfigurationOptions; use super::objects::PredicateFactor; use super::weights::{negative_feature, positive_feature, ExponentialWeights}; use crate::common::interface::{PropositionDB, PredictStatistics, TrainStatistics}; use crate::common::model::InferenceModel; use crate::common::model::{FactorContext, FactorModel}; use crate::common::redis::RedisManager; use crate::common::resources::FactoryResources; use crate::model::objects::Predicate; use crate::model::weights::CLASS_LABELS; use crate::{print_yellow, print_blue}; use redis::Connection; use std::cell::RefCell; use std::collections::HashMap; use std::error::Error; use std::rc::Rc; pub struct ExponentialModel { config: ConfigurationOptions, weights: ExponentialWeights, } impl ExponentialModel { pub fn new_mutable(resources: &FactoryResources) -> Result, Box> { let connection = resources.redis.get_connection()?; let weights = ExponentialWeights::new(connection); Ok(Box::new(ExponentialModel { config: resources.config.clone(), weights, })) } pub fn new_shared(resources: &FactoryResources) -> Result, Box> { let connection = resources.redis.get_connection()?; let weights = ExponentialWeights::new(connection); Ok(Rc::new(ExponentialModel { config: resources.config.clone(), weights, })) } } fn dot_product(dict1: &HashMap, dict2: &HashMap) -> f64 { let mut result = 0.0; for (key, &v1) in dict1 { if let Some(&v2) = dict2.get(key) { let product = v1 * v2; print_blue!("dot_product: key {}, v1 {}, v2 {}, product {}", key, v1, v2, product); result += product; } // In case of null (None), we skip the key as per the original JavaScript logic. } result } pub fn compute_potential(weights: &HashMap, features: &HashMap) -> f64 { let dot = dot_product(weights, features); dot.exp() } pub fn features_from_factor( factor: &FactorContext, ) -> Result>, Box> { let mut vec_result = vec![]; for class_label in CLASS_LABELS { let mut result = HashMap::new(); for (i, premise) in factor.factor.iter().enumerate() { debug!("Processing backimplication {}", i); let feature = premise.inference.unique_key(); debug!("Generated unique key for feature: {}", feature); let probability = factor.probabilities[i]; debug!( "Conjunction probability for backimplication {}: {}", i, probability ); let posf = positive_feature(&feature, class_label); let negf = negative_feature(&feature, class_label); result.insert(posf.clone(), probability); result.insert(negf.clone(), 1.0 - probability); debug!( "Inserted features for backimplication {}: positive - {}, negative - {}", i, posf, negf ); } vec_result.push(result); } trace!("features_from_backimplications completed successfully"); Ok(vec_result) } pub fn compute_expected_features( probability: f64, features: &HashMap, ) -> HashMap { let mut result = HashMap::new(); for (key, &value) in features { result.insert(key.clone(), value * probability); } result } const LEARNING_RATE: f64 = 0.025; pub fn do_sgd_update( weights: &HashMap, gold_features: &HashMap, expected_features: &HashMap, print_training_loss: bool, ) -> HashMap { let mut new_weights = HashMap::new(); for (feature, &wv) in weights { let gv = gold_features.get(feature).unwrap_or(&0.0); let ev = expected_features.get(feature).unwrap_or(&0.0); let new_weight = wv + LEARNING_RATE * (gv - ev); let loss = (gv - ev).abs(); if print_training_loss { info!( "feature: {}, gv: {}, ev: {}, loss: {}, old_weight: {}, new_weight: {}", feature, gv, ev, loss, wv, new_weight ); } new_weights.insert(feature.clone(), new_weight); } new_weights } impl FactorModel for ExponentialModel { fn initialize_connection( &mut self, implication: &PredicateFactor, ) -> Result<(), Box> { self.weights.initialize_weights(implication)?; Ok(()) } fn train( &mut self, factor: &FactorContext, gold_probability: f64, ) -> Result> { trace!("train_on_example - Getting features from backimplications"); let features = match features_from_factor(factor) { Ok(f) => f, Err(e) => { trace!( "train_on_example - Error in features_from_backimplications: {:?}", e ); return Err(e); } }; let mut weight_vectors = vec![]; let mut potentials = vec![]; for class_label in CLASS_LABELS { for (feature, weight) in &features[class_label] { trace!("feature {:?} {}", feature, weight); } trace!( "train_on_example - Reading weights for class {}", class_label ); let weight_vector = match self .weights .read_weights(&features[class_label].keys().cloned().collect::>()) { Ok(w) => w, Err(e) => { trace!("train_on_example - Error in read_weights: {:?}", e); return Err(e); } }; trace!("train_on_example - Computing probability"); let potential = compute_potential(&weight_vector, &features[class_label]); trace!("train_on_example - Computed probability: {}", potential); potentials.push(potential); weight_vectors.push(weight_vector); } let normalization = potentials[0] + potentials[1]; for class_label in CLASS_LABELS { let probability = potentials[class_label] / normalization; trace!("train_on_example - Computing expected features"); let this_true_prob = if class_label == 0 { 1f64 - gold_probability } else { gold_probability }; let gold = compute_expected_features(this_true_prob, &features[class_label]); let expected = compute_expected_features(probability, &features[class_label]); trace!("train_on_example - Performing SGD update"); let new_weight = do_sgd_update( &weight_vectors[class_label], &gold, &expected, self.config.print_training_loss, ); trace!("train_on_example - Saving new weights"); self.weights.save_weights(&new_weight)?; } trace!("train_on_example - End"); Ok(TrainStatistics { loss: 1f64 }) } fn predict(&self, factor: &FactorContext) -> Result> { let features = match features_from_factor(factor) { Ok(f) => f, Err(e) => { print_yellow!( "inference_probability - Error in features_from_backimplications: {:?}", e ); return Err(e); } }; let mut potentials = vec![]; for class_label in CLASS_LABELS { let this_features = &features[class_label]; for (feature, weight) in this_features.iter() { print_yellow!("feature {:?} {}", &feature, weight); } print_yellow!("inference_probability - Reading weights"); let weight_vector = match self .weights .read_weights(&this_features.keys().cloned().collect::>()) { Ok(w) => w, Err(e) => { print_yellow!("inference_probability - Error in read_weights: {:?}", e); return Err(e); } }; for (feature, weight) in weight_vector.iter() { print_yellow!("weight {:?} {}", &feature, weight); } let potential = compute_potential(&weight_vector, &this_features); print_yellow!("potential for {} {} {:?}", class_label, potential, &factor); potentials.push(potential); } let normalization = potentials[0] + potentials[1]; let marginal = potentials[1] / normalization; print_yellow!("dot_product: normalization {}, marginal {}", normalization, marginal); Ok(PredictStatistics { marginal }) } }