Skip to content

AI编程:HW2 Tensor

· 3 min

一个 Tensor,是一个多维数组。它是深度学习中数据的基本表示形式。Tensor 可以有不同的维度(rank),例如标量(0维)、向量(1维)、矩阵(2维)以及更高维度的数组。

本质上讲,Tensor 是内存(Host 或者 Device)中的一块连续的存储空间。我们的Tensor类需要管理这块内存,并提供一些基本的操作接口。

参考 dmlc/mshadow (一个轻量级的C++张量库) 的设计。

Tensor 类设计#

在HW2中,我们设计一个简单的 Tensor 类,支持以下功能:

Tensor.h
#pragma once
#include <vector>
#include <memory>
#include <cstddef>
// 定义一个枚举类来清晰地表示设备类型
enum class Device {
kCPU,
kGPU
};
class Tensor {
public:
// 构造函数:根据指定的形状和设备分配内存
Tensor(const std::vector<size_t>& shape, Device device = Device::kCPU);
// 默认析构函数即可,因为智能指针会自动管理内存
~Tensor() = default;
// 拷贝构造函数:创建一个新的 Tensor,并深拷贝数据
Tensor(const Tensor& other);
// 赋值运算符
Tensor& operator=(const Tensor& other);
// 将数据移动到 GPU 的方法
Tensor gpu() const;
// 将数据移动到 CPU 的方法
Tensor cpu() const;
// --- 辅助函数 ---
// 获取设备类型
Device device() const { return device_; }
// 获取形状
const std::vector<size_t>& shape() const { return shape_; }
// 获取元素总数
size_t numel() const { return num_elements_; }
// 获取指向数据的原始指针(用于 CUDA 核函数或库函数调用)
float* data() const { return data_ptr_.get(); }
private:
// 使用智能指针管理内存
std::shared_ptr<float> data_ptr_;
// 存储张量的形状
std::vector<size_t> shape_;
// 元素的总数
size_t num_elements_;
// 当前张量所在的设备
Device device_;
};

设计细节#

构造函数#

为了区分 Tensor 所在的设备,我们定义一个 Device 枚举类,包含 kCPUkGPU 两个值。

作业要求我们这么设计接口,且要求使⽤ C++ 标准库中提供的容器定义 Tensor 的形状。

Tensor tensor(shape, device);

我们这么设计构造函数:

Tensor(const std::vector<size_t>& shape, Device device = Device::kCPU);

如果不指定设备,则默认创建在 CPU 上。

内存管理#

作业要求建议使用智能指针来管理内存,避免内存泄漏。我们使用 std::shared_ptr<float> 来管理 Tensor 的数据内存。