trait對(duì)象的輸入和輸出類型之間的Rust不匹配

所以我的問(wèn)題是,我有一個(gè)層特性,輸入和輸出類型如下:

pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;
    
    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}

使用此正向功能:

impl<A: Activation> Layer for DenseLayer<A> {
    type Input = Ix2;
    type Output = Ix2;

    fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;

        self.activation.activate(&z)
    }
}

我有這些,這樣我的前進(jìn)或后退函數(shù)就可以接收例如一個(gè)二維數(shù)組,但仍然輸出一個(gè)只有一維的數(shù)組。然后,我有一個(gè)這個(gè)層特性的包裝器的實(shí)現(xiàn),我想在其中轉(zhuǎn)發(fā)所有層:

pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
    loss_function: &'a dyn Cost,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>, loss_function: &'a dyn Cost) -> Self {
        NeuralNetwork { layers, loss_function }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();

        // todo fix the layer forward changing input to output
        // causing mismatch in the input and output dimensions of forward
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }

        output
    }
}

現(xiàn)在,因?yàn)樵趂or循環(huán)中,我首先輸入input類型,然后從layer.forward接收輸出。在下一次迭代中,它接受類型輸出,但layer.forward只接受類型輸入。至少,這就是我認(rèn)為正在發(fā)生的事情。這似乎是一個(gè)非常簡(jiǎn)單的問(wèn)題,但我真的不確定如何解決這個(gè)問(wèn)題。

Edit 1:

Reproduceable Example:

use ndarray::{Array, Array2, ArrayBase, Dimension, OwnedRepr};

pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}

// A Dense Layer struct
pub struct DenseLayer {
    weights: Array2<f32>,
    biases: Array2<f32>,
}

impl DenseLayer {
    pub fn new(input_size: usize, output_size: usize) -> Self {
        let weights = Array::random((input_size, output_size), rand::distributions::Uniform::new(-0.5, 0.5));
        let biases = Array::zeros((1, output_size));
        DenseLayer { weights, biases }
    }
}

impl Layer for DenseLayer {
    type Input = ndarray::Ix2;  // Two-dimensional input
    type Output = ndarray::Ix2; // Two-dimensional output

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;
        z // Return the output directly without activation
    }
}

// Neural Network struct
pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();

        for layer in &mut self.layers {
            output = layer.forward(&output);
        }

        output
    }
}

fn main() {
    // Create a neural network with one Dense Layer
    let mut dense_layer = DenseLayer::new(3, 2);
    let mut nn = NeuralNetwork::new(vec![dense_layer]);

    // Create an example input (1 batch, 3 features)
    let input = Array::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
    
    // Forward pass
    let output = nn.forward(&input);
    println!("Output: {:?}", output);
}

? 最佳回答:

有兩件事你需要得到NeuralNetwork::forward才能編譯。

  • 您需要限制Layer綁定,以便InputOutput關(guān)聯(lián)的類型是相同的類型。
  • 您需要確保此類型實(shí)現(xiàn)Clone,以便input.clone()可以克隆底層數(shù)組,而不是克隆引用。

這些邊界將把這些限制傳達(dá)給編譯器(注意在impl塊上引入了一個(gè)新的泛型參數(shù)T):

impl<'a, L, T> NeuralNetwork<'a, L>
where
    L: Layer<Input = T, Output = T> + 'a,
    T: Clone,

請(qǐng)注意,您應(yīng)該考慮將NeuralNetwork::new移動(dòng)到具有最小限制的impl塊,因?yàn)闆]有理由需要對(duì)其應(yīng)用這些限制。


還有其他一些compile-time錯(cuò)誤,但我認(rèn)為這些錯(cuò)誤與您試圖解決的問(wèn)題無(wú)關(guān)。特別是,我不清楚為什么你在NeuralNetwork上有'a的生命;你可以完全刪除它,代碼仍然可以編譯。

主站蜘蛛池模板: 日韩电影一区二区三区| 亚洲精品一区二区三区四区乱码| 亚洲国产福利精品一区二区| 无码8090精品久久一区| 亚洲综合av一区二区三区| 精品乱人伦一区二区三区| 国产亚洲福利一区二区免费看 | 午夜AV内射一区二区三区红桃视 | 久久青草精品一区二区三区| 国产丝袜无码一区二区视频| 欧美日韩国产免费一区二区三区 | 国产成人av一区二区三区在线| 午夜影视日本亚洲欧洲精品一区 | 亚洲日韩国产精品第一页一区| 激情综合丝袜美女一区二区| 精品国产高清自在线一区二区三区 | 亚洲欧美日韩国产精品一区| 无码乱码av天堂一区二区| 精品国产一区在线观看| 伊人无码精品久久一区二区| 色噜噜狠狠一区二区| 成人精品一区二区户外勾搭野战| 亚欧成人中文字幕一区 | 精品无码人妻一区二区三区| 福利一区二区三区视频在线观看| 免费观看一区二区三区| 精品少妇一区二区三区视频| 日产精品久久久一区二区| 精品无码一区二区三区亚洲桃色| 国产一区二区三区影院| 无码日韩精品一区二区免费| 理论亚洲区美一区二区三区| 日本一区二区三区不卡视频中文字幕| 香蕉久久AⅤ一区二区三区| 色精品一区二区三区| 激情一区二区三区| 精品国产一区二区三区色欲| 国产麻豆剧果冻传媒一区| 久久精品无码一区二区三区日韩| 亚洲一区二区三区偷拍女厕| 国产美女一区二区三区|