启明办公

 找回密码
 立即注册
搜索
热搜: 活动 交友 discuz
查看: 97|回复: 0

4. 图数据处理管道

[复制链接]

2

主题

6

帖子

10

积分

新手上路

Rank: 1

积分
10
发表于 2023-1-17 04:16:45 | 显示全部楼层 |阅读模式
0.介绍

本内容主要来自于DGL官网:第4章:图数据处理管道 - DGL 0.9.1post1 documentation
DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道。 DGL推荐用户将图数据处理为 dgl.data.DGLDataset 的子类。该类为导入、处理和保存图数据提供了简单而干净的解决方案。本章介绍了如何为用户自己的图数据创建一个DGL数据集。
1.DGLDataset类

DGLDataset 是处理、导入和保存 dgl.data 中定义的图数据集的基类。 它实现了用于处理图数据的基本模版。下面的流程图展示了这个模版的工作方式。


为了处理位于远程服务器或本地磁盘上的图数据集,下面的例子中定义了一个类,称为 MyDataset, 它继承自 dgl.data.DGLDataset。
from dgl.data import DGLDataset
class MyDataset(DGLDataset):
    """ 用于在DGL中自定义图数据集的模板:

    Parameters
    ----------
    url : str
        下载原始数据集的url。
    raw_dir : str
        指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/
    save_dir : str
        处理完成的数据集的保存目录。默认:raw_dir指定的值
    force_reload : bool
        是否重新导入数据集。默认:False
    verbose : bool
        是否打印进度信息。
    """
    def __init__(self,
                 url=None,
                 raw_dir=None,
                 save_dir=None,
                 force_reload=False,
                 verbose=False):
        super(MyDataset, self).__init__(name='dataset_name',
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)

    def download(self):
        # 将原始数据下载到本地磁盘
        pass
    def process(self):
        # 将原始数据处理为图、标签和数据集划分的掩码
        pass
    def __getitem__(self, idx):
        # 通过idx得到与之对应的一个样本
        pass
    def __len__(self):
        # 数据样本的数量
        pass
    def save(self):
        # 将处理后的数据保存至 `self.save_path`
        pass
    def load(self):
        # 从 `self.save_path` 导入处理后的数据
        pass
    def has_cache(self):
        # 检查在 `self.save_path` 中是否存有处理后的数据
        passDGLDataset 类有抽象函数 process(), __getitem__(idx) 和 __len__()。子类必须实现这些函数。同时DGL也建议实现保存和导入函数, 因为对于处理后的大型数据集,这么做可以节省大量的时间, 并且有多个已有的API可以简化此操作。请注意, DGLDataset 的目的是提供一种标准且方便的方式来导入图数据。 用户可以存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息。 诸如采样、划分或特征归一化等操作建议在 DGLDataset 子类之外完成。
2.下载原始数据(可选)

如果用户的数据集已经在本地磁盘中,请确保它被存放在目录 raw_dir 中。 如果用户想在任何地方运行代码而又不想自己下载数据并将其移动到正确的目录中,则可以通过实现函数 download() 来自动完成。
如果数据集是一个zip文件,可以直接继承 dgl.data.DGLBuiltinDataset 类。后者支持解压缩zip文件。 否则用户需要自己实现 download(),具体可以参考 QM7bDataset 类:
import os
from dgl.data.utils import download

def download(self):
    # 存储文件的路径
    file_path = os.path.join(self.raw_dir, self.name + '.mat')
    # 下载文件
    download(self.url, path=file_path)上面的代码将一个.mat文件下载到目录 self.raw_dir。如果文件是.gz、.tar、.tar.gz或.tgz文件,请使用 extract_archive() 函数进行解压缩。以下代码展示了如何在 BitcoinOTCDataset 类中下载一个.gz文件:
from dgl.data.utils import download, check_sha1

def download(self):
    # 存储文件的路径,请确保使用与原始文件名相同的后缀
    gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
    # 下载文件
    download(self.url, path=gz_file_path)
    # 检查 SHA-1
    if not check_sha1(gz_file_path, self._sha1_str):
        raise UserWarning('File {} is downloaded but the content hash does not match.'
                          'The repo may be outdated or download may be incomplete. '
                          'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
    # 将文件解压缩到目录self.raw_dir下的self.name目录中
    self._extract_gz(gz_file_path, self.raw_path)上面的代码会将文件解压缩到 self.raw_dir 下的目录 self.name 中。 如果该类继承自 dgl.data.DGLBuiltinDataset 来处理zip文件, 则它也会将文件解压缩到目录 self.name 中。
一个可选项是用户可以按照上面的示例检查下载后文件的SHA-1字符串,以防作者在远程服务器上更改了文件。
3.处理数据

用户可以在 process() 函数中实现数据处理。该函数假定原始数据已经位于 self.raw_dir 目录中。图上的机器学习任务通常有三种类型:整图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。本节重点介绍了处理图、特征和划分掩码的标准方法。用户指南将以内置数据集为例,并跳过从文件构建图的实现。
处理整图分类数据集
整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。 因此,需要将原始数据处理为 dgl.DGLGraph 对象的列表和标签张量的列表。 此外,如果原始数据已被拆分为多个文件,则可以添加参数 split 以导入数据的特定部分。
下面以 QM7bDataset 为例
from dgl.data import DGLDataset

class QM7bDataset(DGLDataset):
    _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
           'datasets/qm7b.mat'
    _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'

    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(QM7bDataset, self).__init__(name='qm7b',
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    def process(self):
        mat_path = self.raw_path + '.mat'
        # 将数据处理为图列表和标签列表
        self.graphs, self.label = self._load_graph(mat_path)

    def __getitem__(self, idx):
        """ 通过idx获取对应的图和标签

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (dgl.DGLGraph, Tensor)
        """
        return self.graphs[idx], self.label[idx]

    def __len__(self):
        """数据集中图的数量"""
        return len(self.graphs)
函数 process() 将原始数据处理为图列表和标签列表。用户必须实现 __getitem__(idx) 和  __len__() 以进行迭代。 DGL建议让 __getitem__(idx) 返回如上面代码所示的元组 (图,标签)。 用户可以参考 QM7bDataset源代码 以获得 self._load_graph() 和 __getitem__ 的详细信息。
用户还可以向类添加属性以指示一些有用的数据集信息。在 QM7bDataset 中, 用户可以添加属性 num_tasks 来指示此多任务数据集中的预测任务总数:
@property
def num_tasks(self):
    """每个图的标签数,即预测任务数。"""
    return 14在编写完这些代码之后,用户可以按如下所示的方式来使用 QM7bDataset:
import dgl
import torch

from dgl.dataloading import GraphDataLoader

# 数据导入
dataset = QM7bDataset()
num_tasks = dataset.num_tasks

# 创建 dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)

# 训练
for epoch in range(100):
    for g, labels in dataloader:
        # 用户自己的训练代码
        pass训练整图分类模型的完整指南可以在 5.4 整图分类 中找到。
有关整图分类数据集的更多示例,用户可以参考 5.4 整图分类:

  • gindataset
  • minigcdataset
  • qm7bdata
  • tudata
处理节点分类数据集
与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码来指定数据集的划分。 本节以内置数据集 CitationGraphDataset 为例:
此外,DGL推荐重新排列图的节点/边,使得相邻节点/边的ID位于邻近区间内。这个过程 可以提高节点/边的邻居的局部性,为后续在图上进行的计算与分析的性能改善提供可能。 DGL提供了名为 dgl.reorder_graph() 的API用于此优化。更多细节,请参考 下面例子中的 process() 的部分。
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url

class CitationGraphDataset(DGLBuiltinDataset):
    _urls = {
        'cora_v2' : 'dataset/cora_v2.zip',
        'citeseer' : 'dataset/citeseer.zip',
        'pubmed' : 'dataset/pubmed.zip',
    }

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']
        if name.lower() == 'cora':
            name = 'cora_v2'
        url = _get_dgl_url(self._urls[name])
        super(CitationGraphDataset, self).__init__(name,
                                                   url=url,
                                                   raw_dir=raw_dir,
                                                   force_reload=force_reload,
                                                   verbose=verbose)

    def process(self):
        # 跳过一些处理的代码
        # === 跳过数据处理 ===

        # 构建图
        g = dgl.graph(graph)

        # 划分掩码
        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask

        # 节点的标签
        g.ndata['label'] = torch.tensor(labels)

        # 节点的特征
        g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                       dtype=F.data_type_dict['float32'])
        self._num_tasks = onehot_labels.shape[1]
        self._labels = labels
        # 重排图以获得更优的局部性
        self._g = dgl.reorder_graph(g)

    def __getitem__(self, idx):
        assert idx == 0, "这个数据集里只有一个图"
        return self._g

    def __len__(self):
        return 1为简便起见,这里省略了 process() 中的一些代码,以突出展示用于处理节点分类数据集的关键部分:划分掩码。 节点特征和节点的标签被存储在 g.ndata 中。详细的实现请参考 CitationGraphDataset源代码 。
请注意,这里 __getitem__(idx) 和 __len__() 的实现也发生了变化, 这是因为节点分类任务通常只用一个图。掩码在PyTorch和TensorFlow中是bool张量,在MXNet中是float张量。
下面中使用 dgl.data.CitationGraphDataset 的子类 dgl.data.CiteseerGraphDataset 来演示如何使用用于节点分类的数据集:
# 导入数据
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]

# 获取划分的掩码
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']

# 获取节点特征
feats = graph.ndata['feat']

# 获取标签
labels = graph.ndata['label']5.1 节点分类/回归 提供了训练节点分类模型的完整指南。
有关节点分类数据集的更多示例,用户可以参考以下内置数据集:

  • citationdata
  • corafulldata
  • amazoncobuydata
  • coauthordata
  • karateclubdata
  • ppidata
  • redditdata
  • sbmdata
  • sstdata
  • rdfdata
处理链接预测数据集
链接预测数据集的处理与节点分类相似,数据集中通常只有一个图。
本节以内置的数据集 KnowledgeGraphDataset 为例,同时省略了详细的数据处理代码以突出展示处理链接预测数据集的关键部分:
# 创建链接预测数据集示例
class KnowledgeGraphDataset(DGLBuiltinDataset):
    def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
        self._name = name
        self.reverse = reverse
        url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
        super(KnowledgeGraphDataset, self).__init__(name,
                                                    url=url,
                                                    raw_dir=raw_dir,
                                                    force_reload=force_reload,
                                                    verbose=verbose)

    def process(self):
        # 跳过一些处理的代码
        # === 跳过数据处理 ===

        # 划分掩码
        g.edata['train_mask'] = train_mask
        g.edata['val_mask'] = val_mask
        g.edata['test_mask'] = test_mask

        # 边类型
        g.edata['etype'] = etype

        # 节点类型
        g.ndata['ntype'] = ntype
        self._g = g

    def __getitem__(self, idx):
        assert idx == 0, "这个数据集只有一个图"
        return self._g

    def __len__(self):
        return 1如代码所示,图的 edata 存储了划分掩码。在 KnowledgeGraphDataset 源代码 中可以查看完整的代码。下面使用 ``KnowledgeGraphDataset``的子类 dgl.data.FB15k237Dataset 来做演示如何使用用于链路预测的数据集:
from dgl.data import FB15k237Dataset

# 导入数据
dataset = FB15k237Dataset()
graph = dataset[0]

# 获取训练集掩码
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = graph.edges(train_idx)

# 获取训练集中的边类型
rel = graph.edata['etype'][train_idx]有关训练链接预测模型的完整指南,请参见 5.3 链接预测。
有关链接预测数据集的更多示例,请参考DGL的内置数据集:

  • kgdata
  • bitcoinotcdata
4. 保存和加载数据

DGL建议用户实现保存和加载数据的函数,将处理后的数据缓存在本地磁盘中。 这样在多数情况下可以帮用户节省大量的数据处理时间。DGL提供了4个函数让任务变得简单。
(1)dgl.save_graphs() 和 dgl.load_graphs(): 保存DGLGraph对象和标签到本地磁盘和从本地磁盘读取它们。
(2)dgl.data.utils.save_info() 和 dgl.data.utils.load_info(): 将数据集的有用信息(python dict对象)保存到本地磁盘和从本地磁盘读取它们。
下面的示例显示了如何保存和读取图和数据集信息的列表。
import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_info

def save(self):
    # 保存图和标签
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    save_graphs(graph_path, self.graphs, {'labels': self.labels})
    # 在Python字典里保存其他信息
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    save_info(info_path, {'num_classes': self.num_classes})

def load(self):
    # 从目录 `self.save_path` 里读取处理过的数据
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    self.graphs, label_dict = load_graphs(graph_path)
    self.labels = label_dict['labels']
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    self.num_classes = load_info(info_path)['num_classes']

def has_cache(self):
    # 检查在 `self.save_path` 里是否有处理过的数据文件
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)请注意:有些情况下不适合保存处理过的数据。例如,在内置数据集 GDELTDataset 中, 处理过的数据比较大。所以这个时候,在 __getitem__(idx) 中处理每个数据实例是更高效的方法。
5. 使用ogb包导入OGB数据集
Open Graph Benchmark (OGB) 是一个图深度学习的基准数据集。 官方的 ogb 包提供了用于下载和处理OGB数据集到 dgl.data.DGLGraph 对象的API。本节会介绍它们的基本用法。
首先使用pip安装ogb包:
pip install ogb以下代码显示了如何为 Graph Property Prediction 任务加载数据集。
# 载入OGB的Graph Property Prediction数据集
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from dgl.dataloading import GraphDataLoader

def _collate_fn(batch):
    # 小批次是一个元组(graph, label)列表
    graphs = [e[0] for e in batch]
    g = dgl.batch(graphs)
    labels = [e[1] for e in batch]
    labels = torch.stack(labels, 0)
    return g, labels

# 载入数据集
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)加载 Node Property Prediction 数据集类似,但要注意的是这种数据集只有一个图对象。
# 载入OGB的Node Property Prediction数据集
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name='ogbn-proteins')
split_idx = dataset.get_idx_split()

# there is only one graph in Node Property Prediction datasets
# 在Node Property Prediction数据集里只有一个图
g, labels = dataset[0]
# 获取划分的标签
train_label = dataset.labels[split_idx['train']]
valid_label = dataset.labels[split_idx['valid']]
test_label = dataset.labels[split_idx['test']]每个 Link Property Prediction 数据集也只包括一个图。
# 载入OGB的Link Property Prediction数据集
from ogb.linkproppred import DglLinkPropPredDataset

dataset = DglLinkPropPredDataset(name='ogbl-ppa')
split_edge = dataset.get_edge_split()

graph = dataset[0]
print(split_edge['train'].keys())
print(split_edge['valid'].keys())
print(split_edge['test'].keys())
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver|手机版|小黑屋|天恒办公

Copyright © 2001-2013 Comsenz Inc.Template by Comsenz Inc.All Rights Reserved.

Powered by Discuz!X3.4

快速回复 返回顶部 返回列表