python, pyTorch

torch_geometric.data format 전환하기

zooyeonii 2021. 9. 8. 02:29

1. torch_geometric.data Format

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

GNN 모델을 구현할 때 주로 torch_geometric 패키지를 사용한다. 
Data는 정해진 포맷이 있는데, 다음과 같다. 

  • x (Tensor) : Node feature matrix이다. shape : [num_nodes, num_node_features]
  • edge_index (Long Tensor) : Graph connectivity format. shape : [2, num_edges] 

edge_index 가 어떤 포맷인지 감이 안올 수 있다. 
만약 (1,2) (1,3) (2,3) node pair 가 연결되어 있다고 할 때,
source node [1,1,2]
target node [2,3,3] 으로 저장하는 것이다. 

edge_index 예시

2. list --> tensor

x (Tensor) : node feature matrix format 에 맞추기 위해 
준비한 double list를 tensor로 바꾼다. 

python 의 자료형 : list / pytorch 의 자료형 : tensor 

>>> import torch
>>> a = [[1,2,3],[1,1,1]]
>>> a = torch.tensor(a)
>>> print(a) 
tensor([[1,2,3],
	[1,1,1]])

tensor --> list

>>> a = a.tolist()

3. scipy sparse matrix --> edge_index

직접 생성한 그래프 데이터로 모델링해보자. 

GCN 모델은 다음과 같다. 

class GCN(torch.nn.Module):
    def __init__(self, input_features, hidden_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_features, hidden_features)
        self.conv2 = GCNConv(hidden_features, num_classes)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        out = F.log_softmax(x,1) 
        return x, out

여기 data는 x와 edge_index 를 의미한다. 

그래프는 다양한 방법으로 생성할 수 있다. 가장 기본적인 Erdos Graph, G(n, p) graph를 생성한다. 
n은 노드의 수, p는 노드간 edge가 생성될 확률을 의미한다. 

import torch
from torch_geometric.data import Data
import networkx as nx

G = nx.fast_gnp_random_graph(10, 0.5)
adj_matrix = nx.adjacency_matrix(G)

networkx 패키지에는 그래프 생성, 분석 관련 모듈이 웬만한 건 다 있다. 
G는 노드 10개, 링크 생성확률은 0.5인 그래프이다. 
node_feature는 따로 준비했으므로, 여기선 edge_index를 뽑아보자. 

  • nx.fast_gnp_random_graph(n,p) : return SciPy sparse matrix 

SciPy sparse matrix 가 뭔가 출력해보았다. 

adjacency matrix 가 scipy sparse matrix 로 저장됨. 

(source node, target node)       edge weight 
의 포맷을 가진다. 

이렇게 바꾸는 것도 가능함!

이를 edge_index 텐서로 바꿔주면 쉽게 데이터를 적용할 수 있다. 
scipy sparse matrix --> edge_index 함수도 torch_geometric.utils 에서 제공하고 있다. 

from torch_geometric.utils import from_scipy_sparse_matrix
Z = from_scipy_sparse_matrix(adj_matrix)

아주 간단하다. 결과는 다음과 같다. 

tensor([[source node list],[target node list]]), tensor([edge weight list]) 의 format을 가진다. 
따라서 edge_index 는 Z [0]로 사용 가능하다.