需求
- 计算两个长度为2的幂次方的向量的对应位置相乘相加结果
- 输入为补码,输出为补码(支持负数)
- 输入位宽可配置,输入向量的宽度可配置,输出位宽由以上两项决定
设计规划
参数表
参数名称 | 说明 | 默认值 |
---|---|---|
DIN_WIDTH | 输入位宽 | 8 |
DIN_NUM_LOG | 输入向量的宽度的log2值(宽度$$2^{DIN_NUM_LOG}$$) | 2 |
注:输出位宽由以上决定,为$$DOUT_WIDTH = DIN_WIDTH \times 2 + DIN_NUM_LOG - 1$$
端口列表
端口名 | 类型 | 位宽 | 说明 |
---|---|---|---|
clk | input | 1 | 系统时钟 |
rst_n | input | 1 | 系统复位,异步复位,低有效 |
din_valid | input | 1 | 输入数据有效,高有效 |
mla_din1 | input | (2 ** DIN_NUM_LOG) * DIN_WIDTH | 输入向量1 |
mla_din2 | input | (2 ** DIN_NUM_LOG) * DIN_WIDTH | 输入向量2 |
dout_valid | output | 1 | 输出信号有效,高有效 |
mla_dout | output | DIN_WIDTH * 2 + DIN_NUM_LOG - 1 | 输出结果 |
功能描述
功能
$mla_dout = \sum_{i = 0}^{2^{DIN_NUM_LOG} - 1} mla_din1[i] \times mla_din2[i]$
其中,mla_din1[i]和mla_din2[i]按位宽依次存储在输入mla_din1和mla_din2中(第i个元素占用第DIN_WIDTH × i位起的DIN_WIDTH位),每个元素均为补码,如图:
时序
- 当输入有效din_valid有效时,开始计算;当dout_valid有效时,结果有效。
- 设计为流水线式,输入可以连续送入,输出可以连续输出,输出具有保序性。
- 输入有效到输出有效的间隔由DIN_NUM_LOG决定,为DIN_NUM_LOG + 2个时钟周期(输入解码1级、乘法器1级、加法树DIN_NUM_LOG级寄存器)
结构
以DIN_NUM_LOG=2为例:
- 输入的mla_din在内部被分解并对应相乘
- 使用每层带寄存器的加法树实现累加
- 有效信号随对应数据流动
代码实现
RTL设计
module声明
// mla_tree: pipelined multiply-accumulate of two packed vectors.
// Each input bus carries 2**DIN_NUM_LOG elements of DIN_WIDTH bits each;
// the module body reads element j from bits [DIN_WIDTH*j +: DIN_WIDTH].
// The result is DIN_WIDTH*2 + DIN_NUM_LOG - 1 bits wide.
module mla_tree #(
// Width in bits of one input element.
parameter DIN_WIDTH = 8,
// log2 of the number of elements in each input vector.
parameter DIN_NUM_LOG = 2
)(
// Clock; the registers in this module update on its rising edge.
input clk,
// Reset, active low; used asynchronously (negedge) by the registers in this module.
input rst_n,
// High when mla_din1/mla_din2 carry a new pair of vectors.
input din_valid,
// Packed input vector 1.
input [(2 ** DIN_NUM_LOG) * DIN_WIDTH - 1:0]mla_din1,
// Packed input vector 2.
input [(2 ** DIN_NUM_LOG) * DIN_WIDTH - 1:0]mla_din2,
// High when mla_dout carries a result.
output dout_valid,
// Sum of the element-wise products.
output [DIN_WIDTH * 2 + DIN_NUM_LOG - 2:0]mla_dout
);
解码输入
- 将输入的向量解码进入数组,方便代码编写
- 处理输入数据,将补码形式转为“符号位-原码”表示,便于使用无符号乘法器(注意:原码的数值部分只有DIN_WIDTH - 1位,无法表示最小负数$-2^{DIN_WIDTH - 1}$的绝对值,因此不支持该输入值)
- din_valid随数据流动,产生unpack_valid
// Stage 1: split both input buses into per-element registers and convert each
// element from two's complement to sign-magnitude form:
//   bit  [DIN_WIDTH-1]   = sign (copied from the element's MSB)
//   bits [DIN_WIDTH-2:0] = magnitude (negated low bits when the sign is set)
// NOTE: the magnitude field is only DIN_WIDTH-1 bits wide, so the most
// negative code -2**(DIN_WIDTH-1) wraps to magnitude 0 in this stage.
reg [DIN_WIDTH - 1:0] din1_unpack [2 ** DIN_NUM_LOG - 1:0];
reg [DIN_WIDTH - 1:0] din2_unpack [2 ** DIN_NUM_LOG - 1:0];
// Valid flag travelling alongside the unpacked data.
reg unpack_valid;
integer elem;

always @(posedge clk or negedge rst_n) begin
    if (!rst_n) begin
        unpack_valid <= 1'b0;
        for (elem = 0; elem < 2 ** DIN_NUM_LOG; elem = elem + 1) begin
            din1_unpack[elem] <= {DIN_WIDTH{1'b0}};
            din2_unpack[elem] <= {DIN_WIDTH{1'b0}};
        end
    end else begin
        // The valid flag always follows din_valid; the data registers only
        // capture a new value when din_valid is high and hold otherwise.
        unpack_valid <= din_valid;
        if (din_valid) begin
            for (elem = 0; elem < 2 ** DIN_NUM_LOG; elem = elem + 1) begin
                din1_unpack[elem] <= {
                    mla_din1[DIN_WIDTH * elem + DIN_WIDTH - 1],
                    (mla_din1[DIN_WIDTH * elem + DIN_WIDTH - 1]
                        ? (~mla_din1[DIN_WIDTH * elem +: DIN_WIDTH - 1] + 1'b1)
                        : mla_din1[DIN_WIDTH * elem +: DIN_WIDTH - 1])
                };
                din2_unpack[elem] <= {
                    mla_din2[DIN_WIDTH * elem + DIN_WIDTH - 1],
                    (mla_din2[DIN_WIDTH * elem + DIN_WIDTH - 1]
                        ? (~mla_din2[DIN_WIDTH * elem +: DIN_WIDTH - 1] + 1'b1)
                        : mla_din2[DIN_WIDTH * elem +: DIN_WIDTH - 1])
                };
            end
        end
    end
end
乘法器
- 将输入数值部分相乘,同时手动控制结果的符号位
- unpack_valid随数据流动,产生mul_valid
// Stage 2: multiply the unsigned magnitudes of each element pair and attach
// the product sign by hand. mul_result is sign-magnitude:
//   bit  [2*DIN_WIDTH-2]   = sign
//   bits [2*DIN_WIDTH-3:0] = magnitude of the product
reg [2 * DIN_WIDTH - 2:0] mul_result [2 ** DIN_NUM_LOG - 1:0];
// Valid flag travelling alongside the products.
reg mul_valid;
integer lane;

always @(posedge clk or negedge rst_n) begin
    if (!rst_n) begin
        mul_valid <= 1'b0;
        for (lane = 0; lane < 2 ** DIN_NUM_LOG; lane = lane + 1) begin
            mul_result[lane] <= {(2 * DIN_WIDTH - 1){1'b0}};
        end
    end else begin
        mul_valid <= unpack_valid;
        if (unpack_valid) begin
            for (lane = 0; lane < 2 ** DIN_NUM_LOG; lane = lane + 1) begin
                // Magnitude: unsigned product of the two magnitude fields.
                mul_result[lane][2 * DIN_WIDTH - 3:0] <= din1_unpack[lane][DIN_WIDTH - 2:0] * din2_unpack[lane][DIN_WIDTH - 2:0];

                // Sign: defaults to 0 and is overridden by the XOR of the
                // operand signs only when both operands are non-zero, so a
                // product with a zero operand never carries a negative sign.
                // (The later non-blocking assignment to the same bit wins.)
                mul_result[lane][2 * DIN_WIDTH - 2] <= 1'b0;
                if (din1_unpack[lane] != 'b0 && din2_unpack[lane] != 'b0) begin
                    mul_result[lane][2 * DIN_WIDTH - 2] <= din1_unpack[lane][DIN_WIDTH - 1] ^ din2_unpack[lane][DIN_WIDTH - 1];
                end
            end
        end
    end
end
加法树
- 将输入转换为补码,之后逐级相加
- mul_valid逐级流动,产生每层的layer_dout_valid
// Adder tree: one generated layer per value of k, from k = DIN_NUM_LOG (the
// layer fed by the multiplier) down to k = 1 (the layer producing the final
// sum). Layer k takes 2**k inputs and registers 2**(k-1) pairwise sums, each
// one bit wider than its inputs. The block and signal names (mla_layer,
// layer_dout, layer_dout_valid) are referenced hierarchically, so they must
// not be renamed.
// NOTE: with DIN_NUM_LOG = 0 this loop generates no layer at all.
genvar i,k;
generate
for (k = DIN_NUM_LOG; k > 0; k = k - 1) begin:mla_layer
// Inputs of this layer (two's complement) and their valid flag.
wire [2 * DIN_WIDTH - 2 + (DIN_NUM_LOG - k):0]layer_din[2 ** k - 1:0];
wire layer_din_valid;
// Registered pairwise sums of this layer and their valid flag.
reg [2 * DIN_WIDTH - 1 + (DIN_NUM_LOG - k):0]layer_dout[2 ** (k - 1) - 1:0];
reg layer_dout_valid;
integer x;
if (k == DIN_NUM_LOG) begin
// First layer: convert each sign-magnitude product back to two's complement.
// The magnitude bits are negated when the sign bit is set; the sign bit is
// passed through unchanged as the MSB.
for (i = 0; i < 2 ** k ; i = i + 1) begin
assign layer_din[i][2 * DIN_WIDTH - 3:0] = (mul_result[i][2 * DIN_WIDTH - 2])?(~mul_result[i][2 * DIN_WIDTH - 3:0] + 1'b1):mul_result[i][2 * DIN_WIDTH - 3:0];
assign layer_din[i][2 * DIN_WIDTH - 2] = mul_result[i][2 * DIN_WIDTH - 2];
end
assign layer_din_valid = mul_valid;
end else begin
// Later layers: take data and valid from the previously generated layer
// (index k + 1) through hierarchical references.
for (i = 0; i < 2 ** k; i = i + 1) begin
assign layer_din[i] = mla_layer[k + 1].layer_dout[i];
end
assign layer_din_valid = mla_layer[k + 1].layer_dout_valid;
end
// Registered adders: inputs 2x and 2x+1 are each sign-extended by one bit
// (their MSB is replicated) before being added, so the sum cannot overflow.
// The registers hold their value while layer_din_valid is low.
always @ (posedge clk or negedge rst_n) begin
if (~rst_n) begin
for (x = 0; x < 2 ** (k - 1) ; x = x + 1) begin
layer_dout[x] <= 'b0;
end
end else if (layer_din_valid) begin
for (x = 0; x < 2 ** (k - 1) ; x = x + 1) begin
layer_dout[x] <= {layer_din[2 * x][2 * DIN_WIDTH - 2 + (DIN_NUM_LOG - k)],layer_din[2 * x]} + {layer_din[2 * x + 1][2 * DIN_WIDTH - 2 + (DIN_NUM_LOG - k)],layer_din[2 * x + 1]};
// layer_dout[x] <= layer_din[2 * x] + layer_din[2 * x + 1];
end
end
end
// The valid flag moves one register stage per layer, in step with the data.
always @ (posedge clk or negedge rst_n) begin
if (~rst_n) begin
layer_dout_valid <= 'b0;
end else begin
layer_dout_valid <= layer_din_valid;
end
end
end
endgenerate
取出结果
从循环生成语句的最后一级取出输出
// Module outputs: the valid flag and the single registered sum of the last
// generated adder layer (index 1), read through hierarchical references.
assign dout_valid = mla_layer[1].layer_dout_valid;
assign mla_dout = mla_layer[1].layer_dout[0];
endmodule // mla_tree
Testbench
dut声明与连接
// Testbench top for mla_tree: declares the DUT parameters and interface
// signals and instantiates the DUT with named parameter and port connections.
module tb_mla_tree();
// Element width in bits, forwarded to the DUT.
parameter DIN_WIDTH = 8;
// log2 of the vector length, forwarded to the DUT (2**4 = 16 elements here).
parameter DIN_NUM_LOG = 4;
// DUT interface signals; widths mirror the DUT port declarations.
logic clk;
logic rst_n;
logic din_valid;
logic [(2 ** DIN_NUM_LOG) * DIN_WIDTH - 1:0]mla_din1;
logic [(2 ** DIN_NUM_LOG) * DIN_WIDTH - 1:0]mla_din2;
logic dout_valid;
logic [DIN_WIDTH * 2 + DIN_NUM_LOG - 2:0]mla_dout;
// Device under test.
mla_tree #(
.DIN_WIDTH (DIN_WIDTH),
.DIN_NUM_LOG(DIN_NUM_LOG)
) dut (
.clk (clk),
.rst_n (rst_n),
.din_valid (din_valid),
.mla_din1 (mla_din1),
.mla_din2 (mla_din2),
.dout_valid(dout_valid),
.mla_dout (mla_dout)
);
时钟与复位信号
// Free-running clock: starts low and toggles every 50 time units
// (period of 100 time units, first rising edge at t = 50).
initial clk = 0;
always #50 clk = ~clk;

// Reset: released at t = 0, asserted (low) at t = 5 and released again at
// t = 15, i.e. before the first rising clock edge.
initial begin
    rst_n = 1'b1;
    #5;
    rst_n = 1'b0;
    #10;
    rst_n = 1'b1;
end
激励产生
激励产生函数
// Returns one random DIN_WIDTH-bit stimulus element (two's-complement code).
// The most negative code, -2**(DIN_WIDTH-1) (sign bit set, all other bits
// zero), is never produced; 0 is returned in its place.
function logic[DIN_WIDTH - 1:0] data_random();
    logic [DIN_WIDTH - 1:0] value;
    // $urandom_range bounds are inclusive, so the largest value requested is
    // 2**DIN_WIDTH - 1. Asking for 2**DIN_WIDTH would add one extra outcome
    // that truncates to 0 and skews the distribution towards zero.
    value = (DIN_WIDTH)'($urandom_range(2 ** DIN_WIDTH - 1, 0));
    if (value[DIN_WIDTH - 1] == 1'b1 && value[DIN_WIDTH - 2:0] == 'b0) begin
        return 0; // do not generate -2 ** (DIN_WIDTH - 1)
    end else begin
        return value;
    end
endfunction
激励产生
// Stimulus: drives 100 random vector pairs, changing the data on falling
// clock edges so it is stable around the rising edge where it is sampled.
// din_valid goes high at the first falling edge and stays high while data is
// being driven.
initial begin
    din_valid = 'b0;
    repeat(100) begin
        for (int i = 0; i < 2 ** DIN_NUM_LOG ; i++) begin
            mla_din1[i * DIN_WIDTH +:DIN_WIDTH] = data_random();
            mla_din2[i * DIN_WIDTH +:DIN_WIDTH] = data_random();
        end
        @(negedge clk);
        din_valid = 1;
    end
    // Stop feeding the DUT and let the pipeline drain before halting.
    // Otherwise the results still in flight (DIN_NUM_LOG + 2 register stages)
    // would never be compared. One extra cycle is waited as margin.
    din_valid = 1'b0;
    repeat(DIN_NUM_LOG + 3) begin
        @(negedge clk);
    end
    $stop;
end
参考模型
输入解码函数
将补码形式的logic数据转为带符号的integer类型数据
// Interprets a DIN_WIDTH-bit two's-complement code as a signed integer:
// the code is first taken as an unsigned number, and 2**DIN_WIDTH is
// subtracted when its sign bit is set.
function integer decode(logic[DIN_WIDTH - 1:0] din);
    integer value;
    value = integer'(din);
    if (din[DIN_WIDTH - 1] == 1'b1) begin
        value = value - 2 ** DIN_WIDTH;
    end
    return value;
endfunction
参考模型
// Reference model state: decoded input elements, the expected sum for the
// current input pair, and the queue of expected results awaiting comparison.
integer tb_din1[2 ** DIN_NUM_LOG - 1:0];
integer tb_din2[2 ** DIN_NUM_LOG - 1:0];
integer tb_result;
integer scoreboard[$]; // scoreboard: expected results in arrival order

// On every rising clock edge with din_valid high, decode both input vectors,
// accumulate the element-wise products and queue the expected sum.
always @(posedge clk) begin
    if (din_valid == 1'b1) begin
        tb_result = 0;
        for (int lane = 0; lane < 2 ** DIN_NUM_LOG; lane++) begin
            // Decode this element pair, then add its product to the sum.
            tb_din1[lane] = decode(mla_din1[lane * DIN_WIDTH +:DIN_WIDTH]);
            tb_din2[lane] = decode(mla_din2[lane * DIN_WIDTH +:DIN_WIDTH]);
            tb_result = tb_result + tb_din1[lane] * tb_din2[lane];
        end
        scoreboard.push_back(tb_result);
    end
end
计分板
// Scoreboard check: on every falling clock edge where the DUT flags a valid
// output, pop the oldest expected value and compare it, truncated to the
// output width, against mla_dout. Simulation stops on the first mismatch.
integer tb_compare;
initial begin
    forever begin
        @(negedge clk);
        if (dout_valid) begin
            if (scoreboard.size() == 0) begin
                // A valid output with nothing queued means the DUT produced a
                // result that no input accounts for; report it explicitly
                // instead of comparing against an undefined popped value.
                $display("dout_valid asserted but the scoreboard is empty, mla_dout = %h",mla_dout);
                $stop;
            end else begin
                tb_compare = scoreboard.pop_front();
                if ((DIN_WIDTH * 2 + DIN_NUM_LOG - 1)'(tb_compare) == mla_dout) begin
                    // mla_dout is an unsigned vector holding a two's-complement
                    // value; print it as signed so negative results read the
                    // same as the expected integer on the left.
                    $display("%d == %d",tb_compare,$signed(mla_dout));
                end else begin
                    $display("%h != %h",(DIN_WIDTH * 2 + DIN_NUM_LOG - 1)'(tb_compare),mla_dout);
                    $stop;
                end
            end
        end
    end
end