Initial commit
This commit is contained in:
parent
9d19e9f2a7
commit
73e559fb27
|
|
@ -14,3 +14,4 @@ Cargo.lock
|
|||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
data/draws.csv
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "mixed_logit"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
argmin = { version = "0.5.0", features = ["ndarrayl"] }
|
||||
csv = "1.1.6"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
ndarray = "0.15.4"
|
||||
ndarray-linalg = { version = "0.14.1", features = ["openblas-system"] }
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,320 @@
|
|||
use argmin::prelude::*;
|
||||
use argmin::solver::linesearch::MoreThuenteLineSearch;
|
||||
use argmin::solver::quasinewton::BFGS;
|
||||
use ndarray::prelude::*;
|
||||
use ndarray_linalg::cholesky::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
|
||||
type Alternative = Array1<f64>;
|
||||
|
||||
fn utility(alt: &Alternative, p: ArrayView1<f64>) -> f64 {
|
||||
alt.iter().zip(p.iter()).map(|(x, b)| x * b).sum()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Observation {
|
||||
alternatives: Vec<Alternative>,
|
||||
choice: usize,
|
||||
}
|
||||
|
||||
impl Observation {
|
||||
fn nb_exog(&self) -> usize {
|
||||
self.alternatives[0].len()
|
||||
}
|
||||
|
||||
fn prob(&self, p: ArrayView1<f64>) -> f64 {
|
||||
self.prob_with_choice(p, self.choice)
|
||||
}
|
||||
|
||||
fn prob_with_choice(&self, p: ArrayView1<f64>, choice: usize) -> f64 {
|
||||
let exp_utilities: Array1<_> = self
|
||||
.alternatives
|
||||
.iter()
|
||||
.map(|alt| utility(alt, p).exp())
|
||||
.collect();
|
||||
exp_utilities[choice] / exp_utilities.sum()
|
||||
}
|
||||
|
||||
fn av_prob(&self, draws: ArrayView2<f64>) -> f64 {
|
||||
// draws: R x K
|
||||
draws.outer_iter().map(|p| self.prob(p)).sum::<f64>() / draws.len() as f64
|
||||
}
|
||||
|
||||
fn exog_bar(&self, p: ArrayView1<f64>) -> Array1<f64> {
|
||||
self.alternatives
|
||||
.iter()
|
||||
.enumerate()
|
||||
.fold(Array1::zeros(self.nb_exog()), |accum, (i, exog)| {
|
||||
accum + exog * self.prob_with_choice(p, i)
|
||||
})
|
||||
}
|
||||
|
||||
fn gradient_b(&self, p: ArrayView1<f64>) -> Array1<f64> {
|
||||
self.prob(p) * (&self.alternatives[self.choice] - self.exog_bar(p))
|
||||
}
|
||||
|
||||
fn gradient_lower(&self, p: ArrayView1<f64>, raw_draws: ArrayView1<f64>) -> Array2<f64> {
|
||||
let tmp1: ArrayView2<f64> = raw_draws.into_shape((1, raw_draws.len())).unwrap();
|
||||
let tmp2 = &self.alternatives[self.choice] - self.exog_bar(p);
|
||||
let n = tmp2.len();
|
||||
let tmp: Array2<f64> = tmp1.dot(&tmp2.into_shape((n, 1)).unwrap());
|
||||
self.prob(p) * tmp
|
||||
}
|
||||
|
||||
fn av_prob_prime(
|
||||
&self,
|
||||
draws: ArrayView2<f64>,
|
||||
raw_draws: ArrayView2<f64>,
|
||||
) -> (Array1<f64>, Array2<f64>) {
|
||||
// draws: R x K
|
||||
let init = (
|
||||
Array1::zeros(self.nb_exog()),
|
||||
Array2::zeros((self.nb_exog(), self.nb_exog())),
|
||||
);
|
||||
let gradient =
|
||||
draws
|
||||
.outer_iter()
|
||||
.zip(raw_draws.outer_iter())
|
||||
.fold(init, |accum, (p, d)| {
|
||||
(
|
||||
accum.0 + self.gradient_b(p) / draws.len() as f64,
|
||||
accum.1 + self.gradient_lower(p, d) / draws.len() as f64,
|
||||
)
|
||||
});
|
||||
let av_prob = self.av_prob(draws);
|
||||
(gradient.0 / av_prob, gradient.1 / av_prob)
|
||||
}
|
||||
}
|
||||
|
||||
fn simloglike(
|
||||
b: ArrayView1<f64>,
|
||||
w: ArrayView2<f64>,
|
||||
draws: &Array3<f64>,
|
||||
pop: &Vec<Observation>,
|
||||
) -> f64 {
|
||||
// b: K
|
||||
// w: K x K
|
||||
// lower: K x K
|
||||
// draws: N x R x K
|
||||
// betas: N x R x K
|
||||
// obs: N x K
|
||||
let mut lower = w.to_owned();
|
||||
lower.map_inplace(|v| *v = v.exp().pow(2u16));
|
||||
lower.cholesky_inplace(UPLO::Lower).unwrap();
|
||||
let mut betas = draws.clone();
|
||||
betas.outer_iter_mut().for_each(|mut obs_draws| {
|
||||
obs_draws.outer_iter_mut().for_each(|mut d| {
|
||||
let new_d = d.dot(&lower) + b;
|
||||
d.zip_mut_with(&new_d, |old, new| *old = *new);
|
||||
});
|
||||
});
|
||||
pop.iter()
|
||||
.zip(betas.outer_iter())
|
||||
.map(|(obs, obs_draws)| obs.av_prob(obs_draws).ln())
|
||||
.sum::<f64>()
|
||||
/ pop.len() as f64
|
||||
}
|
||||
|
||||
fn simloglike_prime(
|
||||
b: ArrayView1<f64>,
|
||||
w: ArrayView2<f64>,
|
||||
draws: &Array3<f64>,
|
||||
pop: &Vec<Observation>,
|
||||
) -> (Array1<f64>, Array2<f64>) {
|
||||
// b: K
|
||||
// w: K x K
|
||||
// lower: K x K
|
||||
// draws: N x R x K
|
||||
// betas: N x R x K
|
||||
// obs: N x K
|
||||
let mut lower = w.to_owned();
|
||||
lower.map_inplace(|v| *v = v.exp().pow(2u16));
|
||||
lower.cholesky_inplace(UPLO::Lower).unwrap();
|
||||
let mut betas = draws.clone();
|
||||
betas.outer_iter_mut().for_each(|mut obs_draws| {
|
||||
obs_draws.outer_iter_mut().for_each(|mut d| {
|
||||
let new_d = d.dot(&lower) + b;
|
||||
d.zip_mut_with(&new_d, |old, new| *old = *new);
|
||||
});
|
||||
});
|
||||
let init = (Array1::zeros(b.raw_dim()), Array2::zeros(w.raw_dim()));
|
||||
pop.iter()
|
||||
.zip(betas.outer_iter())
|
||||
.zip(draws.outer_iter())
|
||||
.fold(init, |accum, ((obs, obs_draws), raw_draws)| {
|
||||
let av_prob_prime = obs.av_prob_prime(obs_draws, raw_draws);
|
||||
(
|
||||
accum.0 + av_prob_prime.0 / pop.len() as f64,
|
||||
accum.1 + av_prob_prime.1 / pop.len() as f64,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
struct MixedLogit {
|
||||
population: Vec<Observation>,
|
||||
draws: Array3<f64>,
|
||||
}
|
||||
|
||||
impl MixedLogit {
|
||||
fn simloglike(&self, p: &Array1<f64>) -> f64 {
|
||||
let k = self.population[0].nb_exog();
|
||||
let b = p.slice(s![..k]);
|
||||
let w = p.slice(s![k..]).into_shape((k, k)).unwrap();
|
||||
simloglike(b, w, &self.draws, &self.population)
|
||||
}
|
||||
|
||||
fn simloglike_prime(&self, p: &Array1<f64>) -> Array1<f64> {
|
||||
let k = self.population[0].nb_exog();
|
||||
let b = p.slice(s![..k]);
|
||||
let w = p.slice(s![k..]).into_shape((k, k)).unwrap();
|
||||
let (b0, w0) = simloglike_prime(b, w, &self.draws, &self.population);
|
||||
b0.iter().chain(w0.iter()).copied().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl ArgminOp for MixedLogit {
|
||||
type Param = Array1<f64>;
|
||||
type Output = f64;
|
||||
type Hessian = Array2<f64>;
|
||||
type Jacobian = ();
|
||||
type Float = f64;
|
||||
|
||||
fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
|
||||
Ok(self.simloglike(p))
|
||||
}
|
||||
|
||||
fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
|
||||
Ok(self.simloglike_prime(p))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Record {
|
||||
id_obs: usize,
|
||||
a: i32,
|
||||
b: i32,
|
||||
c: i32,
|
||||
beta0: f64,
|
||||
beta1: f64,
|
||||
beta2: f64,
|
||||
epsilon: f64,
|
||||
u: f64,
|
||||
choice: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DrawRecord {
|
||||
id_obs: usize,
|
||||
a: f64,
|
||||
b: f64,
|
||||
c: f64,
|
||||
}
|
||||
|
||||
fn run() -> Result<(), Error> {
|
||||
// Generate input.
|
||||
let nb_obs = 1000;
|
||||
let nb_alts = 4;
|
||||
let nb_params = 3;
|
||||
let nb_draws = 1000;
|
||||
|
||||
println!("Reading observations");
|
||||
|
||||
let file_path = "./data/generated.csv";
|
||||
let file = File::open(file_path)?;
|
||||
let mut rdr = csv::Reader::from_reader(file);
|
||||
let mut population = Vec::new();
|
||||
let mut obs = Observation {
|
||||
alternatives: Vec::new(),
|
||||
choice: 0,
|
||||
};
|
||||
let mut last_id_obs = 0;
|
||||
let mut j = 0;
|
||||
for result in rdr.deserialize() {
|
||||
let record: Record = result?;
|
||||
if record.id_obs != last_id_obs {
|
||||
population.push(obs.clone());
|
||||
obs = Observation {
|
||||
alternatives: Vec::new(),
|
||||
choice: 0,
|
||||
};
|
||||
last_id_obs = record.id_obs;
|
||||
j = 0;
|
||||
}
|
||||
obs.alternatives
|
||||
.push(array![record.a as f64, record.b as f64, record.c as f64]);
|
||||
if record.choice {
|
||||
obs.choice = j;
|
||||
}
|
||||
j += 1;
|
||||
}
|
||||
|
||||
println!("Reading draws");
|
||||
|
||||
let file_path = "./data/draws.csv";
|
||||
let file = File::open(file_path)?;
|
||||
let mut rdr = csv::Reader::from_reader(file);
|
||||
let mut draws = Array3::zeros((nb_obs, nb_draws, nb_params));
|
||||
let mut id_obs = 0;
|
||||
let mut id_draw = 0;
|
||||
for result in rdr.deserialize() {
|
||||
let record: DrawRecord = result?;
|
||||
if record.id_obs != id_obs {
|
||||
id_obs = record.id_obs;
|
||||
id_draw = 0;
|
||||
}
|
||||
draws
|
||||
.slice_mut(s![id_obs, id_draw, ..])
|
||||
.assign(&array![record.a, record.b, record.c,]);
|
||||
id_draw += 1;
|
||||
}
|
||||
|
||||
// Define cost function
|
||||
let mixed_logit = MixedLogit { population, draws };
|
||||
|
||||
let b = array![1.0, -1.0, 2.0];
|
||||
let w = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0]];
|
||||
|
||||
let ll1 = mixed_logit.simloglike(&array![
|
||||
1.0, -1.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0
|
||||
]);
|
||||
println!("{}", ll1);
|
||||
|
||||
let ll2 = mixed_logit.simloglike(&array![
|
||||
1.0, -1.1, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0
|
||||
]);
|
||||
println!("{}", ll2);
|
||||
|
||||
// Define initial parameter vector
|
||||
// let init_param: Array1<f64> = array![-1.2, 1.0];
|
||||
// let init_hessian: Array2<f64> = Array2::eye(2);
|
||||
let init_param: Array1<f64> =
|
||||
array![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
|
||||
let init_hessian: Array2<f64> = Array2::eye(12);
|
||||
|
||||
// set up a line search
|
||||
// let linesearch = MoreThuenteLineSearch::new().c(1e-4, 0.9)?;
|
||||
|
||||
// Set up solver
|
||||
// let solver = BFGS::new(init_hessian, linesearch);
|
||||
|
||||
// Run solver
|
||||
// let res = Executor::new(mixed_logit, solver, init_param)
|
||||
// .add_observer(ArgminSlogLogger::term(), ObserverMode::Always)
|
||||
// .max_iters(100)
|
||||
// .run()?;
|
||||
|
||||
// Wait a second (lets the logger flush everything before printing again)
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
|
||||
// Print result
|
||||
// println!("{}", res);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
if let Err(ref e) = run() {
|
||||
println!("{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue