深度学习万能数据加载代码

作者: 小墨 分类: 深度学习 发布时间: 2021-09-04 00:14 访问量:13,758
FavoriteLoading收藏

#此函数只要稍作修改,就可以加载任意自己想要的数据格式。
def data_loader(data,seq_length,batch_size):
    def reader():
        data_list = []
        label_list = []
        for i in range(len(data)-seq_length):
            data_list.append(data[i:i+seq_length,:-1])
            label_list.append(data[i:i+seq_length,-1])
            if len(data_list) == batch_size:
                data_array = torch.tensor(np.array(data_list), dtype=torch.float)
                label_array = torch.tensor(np.array(label_list), dtype=torch.float)
                yield data_array, label_array
                data_list = []
                label_list = []
        if len(data_list) > 0:
            data_array = torch.tensor(np.array(data_list), dtype=torch.float)
            label_array = torch.tensor(np.array(label_list), dtype=torch.float)
            yield data_array, label_array
    return reader

     

如果觉得小墨的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

9条评论

发表评论