파이썬 네트워크 시각화 소스 코드

전에 파이썬으로 네트워크를 시각화해서 그렸던 글을 올렸었는데 소스 코드를 부탁하신 분이 계셔서 별도의 포스팅으로 올립니다.

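network_viz.py로 그림 파일을 여러 장 만들고 hitcount_viz.py를 사용해서 동영상으로 인코딩했던 것 같네요. 시간이 지나서 기억이 잘 나지 않네요 -_-; (network_viz.py는 그림을 output 폴더에 저장하고, hitcount_viz.py는 마지막 부분에서 mencoder로 현재 폴더의 *.png 파일들을 동영상으로 묶습니다. 그래서 인코딩은 그림이 있는 폴더에서 실행해야 합니다.)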

데이터 파일(msnbc990928.seq)은 http://archive.ics.uci.edu/ml/datasets/MSNBC.com+Anonymous+Web+Data 에서 받았습니다.

network_viz.py

import networkx as nx
import matplotlib.pyplot as plt
from numpy import *
from numpy.random import *

# Sliding-window settings for the commented-out frame loop further down:
# each frame covers ma_window consecutive lines (sessions) of the data file,
# and successive frames start ma_step lines apart.
ma_window, ma_step = 1000, 100

def _relative(values):
	"""Return *values* as a float array scaled so that its mean is 1.

	An all-zero (or empty) input comes back as zeros instead of the NaNs a
	0/0 division would produce, e.g. for a window with no self-transitions.
	"""
	values = asarray(values, dtype=float)
	mean = values.mean() if values.size else 0.0
	if mean == 0:
		return zeros_like(values)
	return values / mean

def filesave(i, out_dir="output"):
	"""Render the current counts as one network frame and save it as a PNG.

	Reads the module-level globals G, pos, categories, visit_raw, edges_raw,
	font, font1 and font2.

	i       -- value shown after "# Visitor: " in the lower-right corner; also
	           used, zero-padded to 7 digits, in the output file name.
	out_dir -- directory the PNG is written to (created if missing).
	"""
	import os

	plt.figure(figsize=(16,9))
	plt.title("MSNBC Website Browsing Path (Dataset 1999)", font)

	# Node size: per-category counts relative to their mean.
	visited = _relative(visit_raw)
	# Node colour: same-category transitions (matrix diagonal) relative to the
	# mean diagonal value.
	self_loops = _relative(edges_raw.diagonal())
	# Edge width/colour: transition counts listed in G.edges() order, which is
	# the order draw_networkx_edges uses when no edgelist is given.
	passed = _relative([edges_raw[u][v] for (u, v) in G.edges()])

	nx.draw_networkx_nodes(G, pos, alpha=0.7, node_color=self_loops, node_size=visited*6000, cmap=plt.cm.Blues)
	nx.draw_networkx_edges(G, pos, alpha=0.4, edge_color=passed, width=passed*2, edge_cmap=plt.cm.Reds, arrows=False)
	nx.draw_networkx_labels(G, pos, labels=categories, font_size=20, font_family='helvetica', font_weight='bold')
	# Captions are placed in axes coordinates: (0, 0) lower-left, (1, 0) lower-right.
	plt.text(0, 0, "edge width = # user passed, node size = # users visited, node color = self-loop ratio", font1, horizontalalignment='left', verticalalignment='bottom', transform=plt.gca().transAxes)
	plt.text(1, 0, "# Visitor: " + str(i), font2, horizontalalignment='right', verticalalignment='bottom', transform=plt.gca().transAxes)
	plt.axis('off')

	if out_dir and not os.path.isdir(out_dir):
		os.makedirs(out_dir)
	try:
		plt.savefig(os.path.join(out_dir, "network_viz_%07d.png" % i)) # save as png
	finally:
		# This function is meant to be called once per frame; pyplot keeps every
		# figure alive until it is closed, so close it to avoid piling them up.
		plt.close()

# Load the session log. Each line is expected to hold one session as
# space-separated 1-based category IDs (see the counting loop below).
# A context manager closes the file even if reading fails.
with open("msnbc990928.seq") as seq_file:
	data = seq_file.read().splitlines()
print(len(data))

# Text styles for filesave(): `font` is the figure title, `font1` the caption
# in the lower-left corner, `font2` the visitor counter in the lower-right.
# All three are bold black Helvetica and differ only in size. The captions are
# positioned in axes coordinates (transAxes), not data coordinates.
font, font1, font2 = (
	{'fontname': 'Helvetica', 'color': 'k', 'fontweight': 'bold', 'fontsize': size}
	for size in (24, 14, 30)
)

# Category display names, in the order of the category IDs used in the data
# file (ID k is stored at index k - 1).
category_names = "Frontpage News Tech Local Opinion On-air Misc Weather MSN-News Health Living Business MSN-Sports Sports Summary BBS Travel".split()
print(category_names)

# Node labels for drawing: node index -> display name.
categories = dict(enumerate(category_names))

# Directed complete graph: one node per category and one edge for every
# ordered pair of distinct categories (complete_graph adds no self-loops).
G = nx.complete_graph(len(category_names)).to_directed()

# Fixed node positions (node index -> [x, y]). These appear to be a printed
# result of nx.spring_layout(G, iterations=100) pasted into the source,
# presumably so every frame and every run uses the same layout (spring_layout
# is randomised). To regenerate, run that call once, print it and paste here;
# computing it on every run is unnecessary because only these values are used.
pos = {
	0: array([ 0.42715358,  0.52961866]),
	1: array([ 0.6006982 ,  0.45497008]),
	2: array([ 0.39653568,  0.43281288]),
	3: array([ 0.25918795,  0.44035879]),
	4: array([ 0.31563072,  0.36506269]),
	5: array([ 0.6059773 ,  0.55317834]),
	6: array([ 0.41936401,  0.35313178]),
	7: array([ 0.44428007,  0.6528567 ]),
	8: array([ 0.55117116,  0.38336754]),
	9: array([ 0.49707884,  0.46947535]),
	10: array([ 0.54038237,  0.58690058]),
	11: array([ 0.3004154 ,  0.65209351]),
	12: array([ 0.31939352,  0.50661267]),
	13: array([ 0.24905028,  0.55743071]),
	14: array([ 0.48420206,  0.36347243]),
	15: array([ 0.36604202,  0.65038191]),
	16: array([ 0.52645572,  0.6663619 ]),
}

# Running totals:
#   visit_raw[c]    -- number of clicks on category c (every occurrence in a
#                      session counts, including repeats)
#   edges_raw[a][b] -- number of times a session went from category a directly
#                      to category b (the diagonal counts a == b repeats)
visit_raw = zeros(17)
edges_raw = zeros((17,17))

# Sliding-window variant used to render one frame per window of ma_window
# sessions, advancing by ma_step sessions (kept for reference):
#for j in range(800,900):
#	for d in data[j*ma_step:j*ma_step+ma_window]:
#		last = -1
#		for i in d.split():
#			current = int(i)-1
#			if last > -1:
#				edges_raw[last][current] += 1
#			last = current
#			visit_raw[last] += 1
#	print j
#	filesave(j)
#	visit_raw = zeros(17)
#	edges_raw = zeros((17,17))

n_sessions = 0  # number of lines actually counted as sessions
for (ind, d) in enumerate(data):
#	if (ind % 1649 == 1648) and ind > 800000 and ind < 1000000:
#	if (ind % 1649 == 1648) and ind > 200000 and ind < 300000:
#		print ind
#		filesave(ind)
#		visit_raw = zeros(17)
#		edges_raw = zeros((17,17))
#		if ind > 50000: break
	tokens = d.split()
	# NOTE(review): the UCI copy of msnbc990928.seq appears to start with
	# '%'-prefixed header lines and a line of category names; int() would raise
	# on those, so only lines made purely of numeric IDs are counted.
	if not tokens or not all(t.isdigit() for t in tokens):
		continue
	n_sessions += 1
	last = -1
	for tok in tokens:
		current = int(tok) - 1  # IDs in the file are 1-based
		if last > -1:
			edges_raw[last][current] += 1
		last = current
		visit_raw[last] += 1

# Render the totals for the whole dataset; the "# Visitor" caption shows the
# number of sessions counted.
filesave(n_sessions)
print("Program finished")
#os.system("c:\\badak\\mencoder 'mf://*.png' -mf type=png:fps=20 -ovc lavc -oac copy -o hitcount_viz.avi")
#plt.show() # display

hitcount_viz.py

import os
import sys
import matplotlib.pyplot as plt
from numpy import *
from numpy.random import *

# Load the session log (one session per line). A context manager closes the
# file even if reading fails, and avoids shadowing the builtin `input`.
with open("msnbc990928.seq") as seq_file:
	data = seq_file.read().splitlines()

# Bold black 14 pt Helvetica, used for the chart title in the (commented-out)
# frame loop below.
font = {'fontname': 'Helvetica', 'color': 'k', 'fontweight': 'bold', 'fontsize': 14}

# Category names in the order of the category IDs used in the data file
# (ID k is counted at index k - 1).
categories = "frontpage news tech local opinion on-air misc weather msn-news health living business msn-sports sports summary bbs travel".split()
#categories.reverse()

print(len(data))

val = zeros(len(categories))    # the bar lengths: hits per category
pos = arange(len(categories))+5    # the bar centers on the y axis
print(val)

plt.xlabel('Visitors')

# Frame loop used to render one bar chart every 653 sessions (kept for
# reference). NOTE(review): nothing clears the axes between frames, so each
# frame's bars are drawn on top of the previous frames' bars; and int() would
# raise on any non-numeric line (e.g. a file header) if this is re-enabled.
#for (ind, d) in enumerate(data):
#	if ind % 653 == 0:
#		print ind
#		plt.barh(pos, val, color='r', alpha=0.5, align='center')
#		plt.yticks(pos, categories)
#		plt.title("MSNBC Website Page Hit Count (Dataset 1999) | # of Visitors = " + str(ind), font)
#		plt.savefig("_hitcount_viz_%07d.png" % ind) # save as png
#	for i in d.split():
#		val[int(i)-1] += 1
#
#print val

import subprocess

print('Making movie hitcount_viz.avi - this may take a while')
# Pass the arguments as a list with no shell in between: cmd.exe does not
# strip single quotes, so a quoted 'mf://*.png' would reach mencoder with the
# quotes still attached. mencoder's mf:// input expands the *.png mask itself.
try:
	ret = subprocess.call([
		"c:\\badak\\mencoder", "mf://*.png",
		"-mf", "type=png:fps=25",
		"-ovc", "lavc",
		"-oac", "copy",
		"-o", "hitcount_viz.avi",
	])
except OSError as err:
	print("Could not run mencoder: %s" % err)
else:
	if ret != 0:
		print("mencoder exited with status %d" % ret)
#plt.show()