So my problem is that I have a Layer trait with the following input and output types:
pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;
    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}
with this forward function:
impl<A: Activation> Layer for DenseLayer<A> {
    type Input = Ix2;
    type Output = Ix2;

    fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;
        self.activation.activate(&z)
    }
}
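For illustration, the reason Input and Output are separate associated types is that a layer should be able to change the dimensionality of the data. A made-up FlattenLayer (not part of my actual code) could look like:

pub struct FlattenLayer;

impl Layer for FlattenLayer {
    type Input = Ix2;
    type Output = Ix1;

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output> {
        // Collapse the (batch, features) matrix into one flat vector.
        input.iter().copied().collect()
    }
}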
I have these so that my forward or backward function can take, for example, a two-dimensional array but still output an array with only one dimension, as in the sketch above. Then I have an implementation of a wrapper around this Layer trait, in which I want to forward through all the layers:
pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
    loss_function: &'a dyn Cost,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>, loss_function: &'a dyn Cost) -> Self {
        NeuralNetwork { layers, loss_function }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();
        // todo fix the layer forward changing input to output
        // causing mismatch in the input and output dimensions of forward
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }
        output
    }
}
現(xiàn)在,因?yàn)樵趂or循環(huán)中,我首先輸入input類型,然后從layer.forward接收輸出。在下一次迭代中,它接受類型輸出,但layer.forward只接受類型輸入。至少,這就是我認(rèn)為正在發(fā)生的事情。這似乎是一個(gè)非常簡(jiǎn)單的問(wèn)題,但我真的不確定如何解決這個(gè)問(wèn)題。
Edit 1:
Reproducible Example:
use ndarray::{Array, Array2, ArrayBase, Dimension, OwnedRepr};
use ndarray_rand::RandomExt; // provides `Array::random` (from the ndarray-rand crate)

pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;
    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}

// A Dense Layer struct
pub struct DenseLayer {
    weights: Array2<f32>,
    biases: Array2<f32>,
}

impl DenseLayer {
    pub fn new(input_size: usize, output_size: usize) -> Self {
        let weights = Array::random((input_size, output_size), rand::distributions::Uniform::new(-0.5, 0.5));
        let biases = Array::zeros((1, output_size));
        DenseLayer { weights, biases }
    }
}

impl Layer for DenseLayer {
    type Input = ndarray::Ix2; // Two-dimensional input
    type Output = ndarray::Ix2; // Two-dimensional output

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;
        z // Return the output directly without activation
    }
}

// Neural Network struct
pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }
        output
    }
}

fn main() {
    // Create a neural network with one Dense Layer
    let dense_layer = DenseLayer::new(3, 2);
    let mut nn = NeuralNetwork::new(vec![dense_layer]);
    // Create an example input (1 batch, 3 features)
    let input = Array::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
    // Forward pass
    let output = nn.forward(&input);
    println!("Output: {:?}", output);
}
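For completeness: this example assumes the ndarray, ndarray-rand, and rand crates as dependencies; Array::random is not part of ndarray itself but comes from ndarray-rand's RandomExt trait.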
There are two things you need for NeuralNetwork::forward to compile:

1. A bound on Layer so that the Input and Output associated types are the same type.
2. A Clone bound, so that input.clone() clones the underlying array and not the reference.

These bounds will convey those restrictions to the compiler (note the new generic parameter T introduced on the impl block):
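In sketch form, assuming the definitions from the question (the exact layout of the bounds is illustrative):

impl<'a, L, T> NeuralNetwork<'a, L>
where
    L: Layer<Input = T, Output = T> + 'a,
    // `Clone` is what lets `input.clone()` below clone the owned array rather
    // than the `&` reference; ndarray's `Dimension` already has `Clone` as a
    // supertrait, so spelling it out here is only for emphasis.
    T: Dimension + Clone,
{
    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, T>) -> ArrayBase<OwnedRepr<f32>, T> {
        let mut output = input.clone();
        for layer in &mut self.layers {
            // `Input` and `Output` are the same type `T` now, so the output of
            // one layer is a valid input to the next.
            output = layer.forward(&output);
        }
        output
    }
}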
Note that you should consider moving NeuralNetwork::new to an impl block with the fewest restrictions, since there is no reason these bounds need to apply to it (a sketch of that split follows below). There are some other compile-time errors, but I believe they are unrelated to the problem you're trying to solve. In particular, it is not clear to me why you have the 'a lifetime on NeuralNetwork; you can remove it entirely and the code will still compile.
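For instance, a minimal sketch of that split, assuming the unused 'a lifetime has been removed from the struct as suggested:

impl<L: Layer> NeuralNetwork<L> {
    // Constructing the network only stores the layers, so none of
    // `forward`'s extra bounds are needed here.
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }
}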