1. torch_geometric.data Format
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
GNN 모델을 구현할 때 주로 torch_geometric 패키지를 사용한다.
Data는 정해진 포맷이 있는데, 다음과 같다.
- x (Tensor) : Node feature matrix이다. shape : [num_nodes, num_node_features]
- edge_index (Long Tensor) : Graph connectivity format. shape : [2, num_edges]
edge_index 가 어떤 포맷인지 감이 안올 수 있다.
만약 (1,2) (1,3) (2,3) node pair 가 연결되어 있다고 할 때,
source node [1,1,2]
target node [2,3,3] 으로 저장하는 것이다.
2. list --> tensor
x (Tensor) : node feature matrix format 에 맞추기 위해
준비한 double list를 tensor로 바꾼다.
python 의 자료형 : list / pytorch 의 자료형 : tensor
>>> import torch
>>> a = [[1,2,3],[1,1,1]]
>>> a = torch.tensor(a)
>>> print(a)
tensor([[1,2,3],
[1,1,1]])
tensor --> list
>>> a = a.tolist()
3. scipy sparse matrix --> edge_index
직접 생성한 그래프 데이터로 모델링해보자.
GCN 모델은 다음과 같다.
class GCN(torch.nn.Module):
def __init__(self, input_features, hidden_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(input_features, hidden_features)
self.conv2 = GCNConv(hidden_features, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
out = F.log_softmax(x,1)
return x, out
여기 data는 x와 edge_index 를 의미한다.
그래프는 다양한 방법으로 생성할 수 있다. 가장 기본적인 Erdos Graph, G(n, p) graph를 생성한다.
n은 노드의 수, p는 노드간 edge가 생성될 확률을 의미한다.
import torch
from torch_geometric.data import Data
import networkx as nx
G = nx.fast_gnp_random_graph(10, 0.5)
adj_matrix = nx.adjacency_matrix(G)
networkx 패키지에는 그래프 생성, 분석 관련 모듈이 웬만한 건 다 있다.
G는 노드 10개, 링크 생성확률은 0.5인 그래프이다.
node_feature는 따로 준비했으므로, 여기선 edge_index를 뽑아보자.
- nx.fast_gnp_random_graph(n,p) : return SciPy sparse matrix
SciPy sparse matrix 가 뭔가 출력해보았다.
(source node, target node) edge weight
의 포맷을 가진다.
이를 edge_index 텐서로 바꿔주면 쉽게 데이터를 적용할 수 있다.
scipy sparse matrix --> edge_index 함수도 torch_geometric.utils 에서 제공하고 있다.
from torch_geometric.utils import from_scipy_sparse_matrix
Z = from_scipy_sparse_matrix(adj_matrix)
아주 간단하다. 결과는 다음과 같다.
tensor([[source node list],[target node list]]), tensor([edge weight list]) 의 format을 가진다.
따라서 edge_index 는 Z [0]로 사용 가능하다.
'python, pyTorch' 카테고리의 다른 글
입력받기 input() 여러 개, 리스트 (0) | 2021.11.18 |
---|---|
데이터 전처리를 해보자. (0) | 2021.10.21 |
python) csv 파일 읽어오기/수정하기/쓰기 (0) | 2021.09.07 |
Pycharm 과 Anaconda 연동하기 (0) | 2021.09.02 |
그래프 예쁘게 그리기 2탄 : networkx.draw 함수 (0) | 2021.08.11 |