python, pyTorch

pytorch : txt파일 전처리

zooyeonii 2021. 8. 11. 17:32

readlines 함수 사용하기

f = open("파일경로/파일이름.txt", 'r')
lines = f.readlines() 
for line in lines: 
	print(line) 

f.close()

사용하는 txt 파일 
0,1
1,2
1,3
2,3
... 그래프의 edge를 기록한 파일. 

Data pre-processing : GCN 모델에 사용할 데이터이기 때문에, torch_geometric.data format에 맞춰줄 것이다.

  • x(Tensor) : Node feature matrix. shape : [num_nodes, num_node_features]
  • edge_index(LongTensor) : Graph connectivity format with shape : [2, num_edges] 

위 link.txt 파일을 edge_index 로 옮길 때, format을 바꿔주어야 한다. 

현재 저장되어있는 form

 

현재 tensor shape [num_edges, 2] 이기 때문에

 

아래 그림처럼 저장해주려고 한다. 

edge_index format

 

def load_data(path, node_feat):
    f = open("links.txt", 'r')
    lines = f.readlines()
    sn, tn = [], [] 
    # to make edge_index format. tensor_shape : [2, num_edges]
    for line in lines:
        s, t = line.split(",")
        sn.append(int(s))
        tn.append(int(t))
    edge_index = torch.tensor([sn, tn], dtype=torch.long)
    # node_feature matrix format; tensor_shape : [num_nodes, num_node_features]
    data = Data(x=node_feat, edge_index=edge_index)
    
    f.close()
    return data