亚洲ww无码ww专区1234_亚洲AV综合色区无码三区30p_丰满二级精品一区_美女黄频视频大全免费的正片_久久综合九色综合网站

自制深度學(xué)習(xí)推理框架-實(shí)現(xiàn)我們的第一個算子Relu-第三課-世界時訊

2023-01-02 16:22:52    來源:
我們的課程主頁

https://github.com/zjhellofss/KuiperInfer 歡迎pr和點(diǎn)贊


(相關(guān)資料圖)

手把手教大家去寫一個深度學(xué)習(xí)推理框架 B站視頻課程

Relu算子的介紹

Relu是一種非線性激活函數(shù),它的特點(diǎn)有運(yùn)算簡單,不會在梯度處出現(xiàn)梯度消失的情況,而且它在一定程度上能夠防止深度學(xué)習(xí)模型在訓(xùn)練中發(fā)生的過擬合現(xiàn)象。Relu的公式表達(dá)如下所示,「如果對于深度學(xué)習(xí)基本概念不了解的同學(xué),可以將Relu當(dāng)作一個公式進(jìn)行對待,可以不用深究其背后的含義。」

我們今天的任務(wù)就是來完成這個公式中的操作,「值得注意的是,在我們的項(xiàng)目中,x和y可以理解為我們在第二、第三節(jié)中實(shí)現(xiàn)的張量類(tensor).」

Operator類

Operator類就是我們在第一節(jié)中說過的計(jì)算圖中「節(jié)點(diǎn)」的概念,計(jì)算圖的另外一個概念是數(shù)據(jù)流圖,如果同學(xué)們忘記了這個概念,可以重新重新翻看第一節(jié)課程。

在我們的代碼中我們先定義一個「Operator」類,它是一個父類,其余的Operator,包括我們本節(jié)要實(shí)現(xiàn)的ReluOperator都是其派生類,「Operator中會存放節(jié)點(diǎn)相關(guān)的參數(shù)?!估缭凇窩onvOperator」中就會存放初始化卷積算子所需要的stride, padding, kernel_size等信息,本節(jié)的「ReluOperator」就會帶有「thresh」值信息。

我們從下方的代碼中來了解Operator類和ReluOperator類,它們是父子關(guān)系,Operator是基類,OpType記錄Operator的類型。

enumclassOpType{kOperatorUnknown=-1,kOperatorRelu=0,};classOperator{public:OpTypekOpType=OpType::kOperatorUnknown;virtual~Operator()=default;explicitOperator(OpTypeop_type);};

ReluOperator實(shí)現(xiàn):

classReluOperator:publicOperator{public:~ReluOperator()override=default;explicitReluOperator(floatthresh);voidset_thresh(floatthresh);floatget_thresh()const;private:floatthresh_=0.f;};

Layer類

我們會在operator類中存放從「計(jì)算圖結(jié)構(gòu)文件」得到的信息,例如在ReluOperator中存放的thresh值作為一個參數(shù)就是我們從計(jì)算圖結(jié)構(gòu)文件中得到的,計(jì)算圖相關(guān)的概念我們已經(jīng)在第一節(jié)中講過。

下一步我們需要根據(jù)ReLuOperator類去完成ReluLayer的初始化,「他們的區(qū)別在于ReluOperator負(fù)責(zé)存放從計(jì)算圖中得到的節(jié)點(diǎn)信息,不負(fù)責(zé)計(jì)算」,而ReluLayer則「負(fù)責(zé)具體的計(jì)算操作」,同樣,所有的Layer類有一個公共父類Layer. 我們可以從下方的代碼中來了解兩者的關(guān)系。

classLayer{public:explicitLayer(conststd::string&layer_name);virtualvoidForwards(conststd::vector>>&inputs,std::vector>>&outputs);virtual~Layer()=default;private:std::stringlayer_name_;};

其中Layer的Forwards方法是具體的執(zhí)行函數(shù),負(fù)責(zé)將輸入的inputs中的數(shù)據(jù),進(jìn)行relu運(yùn)算并存放到對應(yīng)的outputs中。

classReluLayer:publicLayer{public:~ReluLayer()override=default;explicitReluLayer(conststd::shared_ptr&op);voidForwards(conststd::vector>>&inputs,std::vector>>&outputs)override;private:std::shared_ptrop_;};

這是集成于Layer的ReluLayer類,我們可以看到其中有一個op成員,是一個ReluOperator指針,「這個指針中負(fù)責(zé)存放ReluLayer計(jì)算時所需要用到的一些參數(shù)」。此處op_存放的參數(shù)比較簡單,只有ReluOperator中的thresh參數(shù)。

我們再看看是怎么使用ReluOperator去初始化ReluLayer的,先通過統(tǒng)一接口傳入Operator類,再轉(zhuǎn)換為對應(yīng)的ReluOperator指針,最后再通過指針中存放的信息去初始化「op_」.

ReluLayer::ReluLayer(conststd::shared_ptr&op):Layer("Relu"){CHECK(op->kOpType==OpType::kOperatorRelu);ReluOperator*relu_op=dynamic_cast(op.get());CHECK(relu_op!=nullptr);this->op_=std::make_shared(relu_op->get_thresh());}

我們來看一下具體ReluLayer的Forwards過程,它在執(zhí)行具體的計(jì)算,完成Relu函數(shù)描述的功能。

voidReluLayer::Forwards(conststd::vector>>&inputs,std::vector>>&outputs){CHECK(this->op_!=nullptr);CHECK(this->op_->kOpType==OpType::kOperatorRelu);constuint32_tbatch_size=inputs.size();for(inti=0;iempty());conststd::shared_ptr>&input_data=inputs.at(i);input_data->data().transform([&](floatvalue){floatthresh=op_->get_thresh();if(value>=thresh){returnvalue;}else{return0.f;}});outputs.push_back(input_data);}}

在for循環(huán)中,首先讀取輸入input_data, 再對input_data使用armadillo自帶的transform按照我們給定的thresh過濾其中的元素,如果「value」的值大于thresh則不變,如果小于thresh就返回0.

最后,我們寫一個測試函數(shù)來驗(yàn)證我們以上的兩個類,節(jié)點(diǎn)op類,計(jì)算層layer類的正確性。先判斷Forwards返回的outputs是否已經(jīng)保存了relu層的輸出,輸出大小應(yīng)該assert為1. 隨后再進(jìn)行比對,我們應(yīng)該知道在thresh等于0的情況下,第一個輸出index(0)和第二個輸出index(1)應(yīng)該是0,第三個輸出應(yīng)該是3.f.

TEST(test_layer,forward_relu){usingnamespacekuiper_infer;floatthresh=0.f;std::shared_ptrrelu_op=std::make_shared(thresh);std::shared_ptr>input=std::make_shared>(1,1,3);input->index(0)=-1.f;input->index(1)=-2.f;input->index(2)=3.f;std::vector>>inputs;std::vector>>outputs;inputs.push_back(input);ReluLayerlayer(relu_op);layer.Forwards(inputs,outputs);ASSERT_EQ(outputs.size(),1);for(inti=0;iindex(0),0.f);ASSERT_EQ(outputs.at(i)->index(1),0.f);ASSERT_EQ(outputs.at(i)->index(2),3.f);}}

本期代碼倉庫位置

gitclonehttps://gitee.com/fssssss/KuiperCourse.gitgitcheckoutfouth

關(guān)鍵詞: 是否已經(jīng) 卷積算子 激活函數(shù)

上一篇:

下一篇:

X 關(guān)閉

財(cái)經(jīng) 查看更多
安陽曹操墓將于2022年5月正式對公眾開放
時間·2021-12-29    來源·中新網(wǎng)
為什么這次寒潮南方降雪這么明顯?
時間·2021-12-29    來源·新華社

X 關(guān)閉