use core::fmt::{Display, Formatter}; use std::collections::HashMap; use std::error::Error as ErrorTrait; use std::fs::File; use std::path::PathBuf; use crate::classfile::{ JavaClassFile, AbstractTypeDescription, AbstractTypeKind }; use crate::classfile; use crate::heap_area::ObjectReference; use crate::iterators::CompatibleTypesIterator; #[derive(Debug)] pub struct ClassStore { pub class_ids: HashMap, pub array_classes: HashMap, pub classes: Vec, pub class_path_fragments: Vec, pub native_class_names: Vec, pub primitive_classes: PrimitiveClassStore, } #[derive(Debug, Default)] pub struct PrimitiveClassStore { pub byte_class: ObjectReference, pub char_class: ObjectReference, pub double_class: ObjectReference, pub float_class: ObjectReference, pub int_class: ObjectReference, pub long_class: ObjectReference, pub short_class: ObjectReference, pub boolean_class: ObjectReference, } #[derive(Debug)] pub struct ClassStoreEntry { was_init: bool, class_object: ObjectReference, class_file: JavaClassFile, } #[derive(Debug)] pub enum Error { ClassNotFoundError(String), IOError(std::io::Error), ClassFileError(String, classfile::Error), } impl From for Error { fn from(value: std::io::Error) -> Self { return Error::IOError( value ); } } impl From for Error { fn from(value: classfile::Error) -> Self { return Error::ClassFileError( "An error occured while loading a classfile".to_string(), value ); } } impl ErrorTrait for Error {} impl Display for Error { fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { writeln!(formatter, "{self}")?; if let Some(e) = self.source() { writeln!(formatter, "\tCaused by: {e:?}")?; } Ok(()) } } impl ClassStore { pub fn new() -> Self { let current_dir_path = PathBuf::from("./"); ClassStore { class_ids: HashMap::new(), array_classes: HashMap::new(), classes: Vec::new(), class_path_fragments: vec![current_dir_path], native_class_names: Vec::new(), primitive_classes: PrimitiveClassStore::default(), } } pub fn class_count(&self) -> usize { return self.classes.len(); } pub fn add_class(&mut self, class_file: JavaClassFile, was_init: bool) -> Result { let classname = class_file.get_classname()?; self.class_ids.insert(classname.to_string(), self.classes.len()); let entry = ClassStoreEntry { was_init, class_object: ObjectReference::NULL, class_file }; self.classes.push(entry); return Ok(self.classes.len() - 1); } pub fn add_native_class_descriptor(&mut self, name: String) -> usize { self.native_class_names.push(name); return self.native_class_names.len() - 1; } pub fn get_native_class_name(&self, index: usize) -> &String { return &self.native_class_names[index]; } pub fn load_class_from_file(&mut self, class_file_path: &PathBuf) -> Result { let mut file_reader = File::open(class_file_path)?; let class_file = JavaClassFile::new(&mut file_reader)?; return self.add_class(class_file, false); } pub fn load_class(&mut self, classname: &String) -> Result { let mut path_buf = PathBuf::new(); for class_path in &self.class_path_fragments { path_buf.push(class_path); path_buf.push(&classname); path_buf.set_extension("class"); if path_buf.is_file() { return self.load_class_from_file(&path_buf); } path_buf.clear(); }; return Err(Error::ClassNotFoundError(format!("Could not find class '{classname}' in classpath"))); } pub fn are_types_compatible(&self, my_type: &AbstractTypeDescription, other_type: &AbstractTypeDescription) -> bool { if my_type == other_type { return true; } if my_type.array_level != other_type.array_level { return false; } if my_type.kind == other_type.kind { return true; } let my_type_name = match &my_type.kind { AbstractTypeKind::Classname(name) => name, _ => unreachable!(), }; let my_type_index = self.class_idx_from_name(&my_type_name).unwrap(); let other_type_name = match &other_type.kind { AbstractTypeKind::Classname(name) => name, _ => unreachable!(), }; let compatible_count = CompatibleTypesIterator::new(my_type_index, self) .filter(|type_name| *type_name == other_type_name ) .count(); compatible_count != 0 } pub fn have_class(&self, classname: &String) -> bool { return self.class_ids.contains_key(classname); } pub fn get_class(&self, classname: &String) -> Result<(&JavaClassFile, usize), Error> { let class_id = self.class_ids.get(classname); return match class_id { Some(id) => Ok((&self.classes[*id].class_file, *id)), None => Err(Error::ClassNotFoundError(format!("Could not locate class '{}'", classname))), } } pub fn get_or_load_class(&mut self, classname: &String) -> Result<(&JavaClassFile, usize), Error> { if self.have_class(classname) { return Ok(self.get_class(classname)?); } else { let class_idx = self.load_class(classname)?; return Ok((&self.classes[class_idx].class_file, class_idx)); } } pub fn class_name_from_index(&self, index: usize) -> Option<&String> { return match self.class_file_from_idx(index) { Some(file) => Some(file.get_classname().unwrap()), None => None, } } pub fn class_file_from_idx(&self, idx: usize) -> Option<&JavaClassFile> { return match self.classes.get(idx) { Some(entry) => Some(&entry.class_file), None => None, } } pub fn class_idx_from_name(&self, classname: &String) -> Option { return self.class_ids.get(classname).copied(); } pub fn was_init(&self, classname: &String) -> Option { let entry = self.classes.get(self.class_idx_from_name(classname).unwrap()).unwrap(); return Some(entry.was_init); } pub fn set_init(&mut self, class_idx: usize, was_init: bool) { let entry = self.classes.get_mut(class_idx).unwrap(); entry.was_init = was_init; } pub fn put_array_class_ref(&mut self, type_desc: AbstractTypeDescription, class_ref: ObjectReference) { self.array_classes.insert(type_desc, class_ref); } pub fn set_class_objectref_by_index(&mut self, index: usize, class_objref: ObjectReference) { self.classes.get_mut(index).unwrap().class_object = class_objref; } pub fn get_class_objectref_from_index(&self, index: usize) -> ObjectReference { self.classes[index].class_object } pub fn get_class_objectref_from_primitive(&self, primitive: AbstractTypeKind) -> Option { match primitive { AbstractTypeKind::Boolean() => { Some(self.primitive_classes.boolean_class) } AbstractTypeKind::Byte() => { Some(self.primitive_classes.byte_class) } AbstractTypeKind::Char() => { Some(self.primitive_classes.char_class) } AbstractTypeKind::Double() => { Some(self.primitive_classes.double_class) } AbstractTypeKind::Float() => { Some(self.primitive_classes.float_class) } AbstractTypeKind::Int() => { Some(self.primitive_classes.int_class) } AbstractTypeKind::Long() => { Some(self.primitive_classes.long_class) } AbstractTypeKind::Short() => { Some(self.primitive_classes.short_class) } _ => todo!(), } } pub fn get_array_class_ref(&self, type_desc: &AbstractTypeDescription) -> Option { return self.array_classes.get(type_desc).copied(); } pub fn class_index_for_type(&self, r#type: AbstractTypeDescription) -> Option { match (r#type.array_level, &r#type.kind) { (0, AbstractTypeKind::Classname(ref name)) => Some(self.class_idx_from_name(name).unwrap()), _ => None } } pub fn class_ref_for_type(&self, r#type: AbstractTypeDescription) -> Option { match (r#type.array_level, &r#type.kind) { (0, AbstractTypeKind::Classname(ref name)) => { let class_index = self.class_idx_from_name(name).unwrap(); Some(self.get_class_objectref_from_index(class_index)) } (0, _) => { self.get_class_objectref_from_primitive(r#type.kind) } (1..=u8::MAX, _) => { self.get_array_class_ref(&r#type) } } } }