#--------------------------------------------------------------
#
#   SimChip2 - Chip simulation
#
#--------------------------------------------------------------

from __future__ import division
from numpy import ndarray, dtype, int8, uint8, float32, zeros, \
	amax, sign, int32
import profile

#--------------------------------------------------------------

#  Every cell of the chip is a vector of seven int8 fields:
#
#    index   field                   values
#      0     metal                   0 or 1
#      1     metal h-connection      0 or 1
#      2     metal v-connection      0 or 1
#      3     silicon                 -2 to +2
#      4     silicon h-connection    0 or 1
#      5     silicon v-connection    0 or 1
#      6     via                     0 or 1

cell_dtype = (int8, 7)
v_dtype = float32

#  Voltage applied for a logic level of 1; half of it is the threshold
#  used when reading a logic level back from a pad.
Vcc = 5.0

#  Index of the first field of each layer group, and of the via field
MET, SIL, VIA = 0, 3, 6

#  Slices covering the three fields of the metal and silicon groups
_MET_, _SIL_ = slice(0, 3), slice(3, 6)

#  Offsets within a layer group; HC and VC are added to MET or SIL to
#  reach the h-connection and v-connection fields.
MS, HC, VC = 0, 1, 2

#  ChipState.auto_mode
AUTO_OFF, AUTO_PAUSE, AUTO_RUN = 0, 1, 2

#--------------------------------------------------------------

class ChipStructure(object):
	"""Layout of a chip as a rectangular grid of cells.

	cells   ndarray[nrows, ncols] of cell_dtype
	"""

	def __init__(self, nrows, ncols, num_pins):
		#  num_pins is accepted but not used by this constructor.
		grid_shape = (nrows, ncols)
		self.cells = zeros(grid_shape, cell_dtype)

	@property
	def size(self):
		"""(number of rows, number of columns) of the cell grid."""
		return self.cells.shape[:2]

	@property
	def num_rows(self):
		"""Number of rows of the cell grid."""
		return self.cells.shape[0]

	@property
	def num_cols(self):
		"""Number of columns of the cell grid."""
		return self.cells.shape[1]

#--------------------------------------------------------------

class ChipState(object):
	"""Dynamic electrical state of a chip, held in numpy arrays.

	Arrays whose leading axis has length 2 hold one plane per layer:
	index 0 is the metal plane and index 1 the silicon plane (see the
	Vm and Vs aliases).

	  Q    [2, nrows, ncols]       charge of each cell
	  V    [2, nrows, ncols]       voltage of each cell
	  Vm   V[0], Vs = V[1]         per-layer aliases of V
	  Ix   [2, nrows, ncols - 1]   current between horizontally adjacent cells
	  Iy   [2, nrows - 1, ncols]   current between vertically adjacent cells
	  Iz   [nrows, ncols]          current between the two planes
	  Z    [2, nrows]              per-row pad impedance; the first index is
	                               the side (0 for columns < 3, 1 otherwise)
	  Zp   Z[:, ::3]               every third row of Z
	  Qp   [side][npads, 3, 3]     first / last three columns of Q[0],
	                               grouped into 3x3 blocks
	  Qo   [side][npads, 3]        column 3 / column -4 of Q[0], grouped
	                               in threes

	Qp and Qo are reshaped slices of Q, intended to alias it so that
	writes through them land in Q.
	"""

	time = 0   #  Number of timesteps elapsed
	auto_mode = AUTO_OFF
	auto_end_time = 0

	def __init__(self, chip):
		"""Allocate zeroed state arrays sized from chip.size.

		chip   object whose size attribute is (nrows, ncols)

		Raises ValueError if nrows is not a multiple of 3, because the
		pad arrays group the rows in threes.
		"""
		nrows, ncols = chip.size
		if nrows % 3:
			raise ValueError(
				"number of rows (%d) must be a multiple of 3" % nrows)
		npads = nrows // 3
		self.Q = zeros((2, nrows, ncols))
		self.V = zeros((2, nrows, ncols))
		self.Ix = zeros((2, nrows, ncols - 1))
		self.Iy = zeros((2, nrows - 1, ncols))
		self.Iz = zeros((nrows, ncols))
		self.Vm = self.V[0]
		self.Vs = self.V[1]
		self.Z = zeros((2, nrows))
		self.Qp = [
			self.Q[0, :, :3].reshape((npads, 3, 3)),
			self.Q[0, :, -3:].reshape((npads, 3, 3))
		]
		self.Qo = [
			self.Q[0, :, 3].reshape((npads, 3)),
			self.Q[0, :, -4].reshape((npads, 3))
		]
		self.Zp = self.Z[:, ::3]

	def put_debug_voltage_pattern(self):
		"""Fill Vm and Vs with diagonal voltage gradients for debugging.

		Vm rises from -5.0 with row + col; Vs uses the column-mirrored
		ramp row + (ncols - col).
		"""
		Vm = self.Vm
		Vs = self.Vs
		#  Was "nrows, ncols = shape", an undefined name (NameError).
		nrows, ncols = Vm.shape
		for row in xrange(nrows):
			for col in xrange(ncols):
				Vm[row, col] = -5.0 + 15.0 * (row + col) / (nrows + ncols)
				Vs[row, col] = -5.0 + 15.0 * (row + (ncols - col)) / (nrows + ncols)

	def set_pad_impedance(self, coords, z):
		"""Store z in Z for the three rows centred on coords.

		coords   (row, col); col < 3 selects side 0, anything else side 1
		z        value to store (the simulation treats it as a flag,
		         presumably 0 or 1)
		"""
		row, col = coords
		#  Convert to int: numpy interprets a bool index as a mask rather
		#  than as 0 / 1, which would address the wrong elements.
		k = int(col >= 3)
		self.Z[k, row-1:row+2] = z

	def voltage_of_pad(self, coords):
		"""Return the metal-layer voltage (charge / Cpad) at coords = (row, col)."""
		row, col = coords
		return self.Q[0, row, col] / Cpad

	def apply_voltage_to_pad(self, v, coords):
		"""Set the 3x3 metal block centred on coords = (row, col) to voltage v."""
		q = v * Cpad
		row, col = coords
		self.Q[0, row-1:row+2, col-1:col+2] = q
	
#--------------------------------------------------------------

#  Resistances.  The silicon resistance is the metal resistance scaled
#  by 2 ** gamma; gamma is also the exponent applied to the silicon
#  cell value in simulation_timestep.
Rmetal = 0.05
gamma = 7
Rsilicon = Rmetal * (2 ** gamma)

#  L appears in the denominator of the current update and C converts
#  charge to voltage; Cpad is used instead of C for pads.
#  Must have L * C >= 2.5 for stability
L = 1.25 #2.5
C = 2.0 #1.0
Cpad = 2.0

#  Factor applied to the metal-to-silicon voltage difference
Fg = 5.0

#print "Rmetal =", Rmetal
#print "Rsil1 =", Rsilicon
#print "Rsil2 =", Rsilicon / (2 ** gamma)

Linv, Cinv = 1.0 / L, 1.0 / C

#  Debugging and profiling switches
stepcount = 0     #  Steps run so far; only counted while profiling
dumping = 0       #  Non-zero enables the output of dump()
dump_rows = 5     #  Maximum number of rows printed by dump()
profiling = 0     #  Non-zero enables profiling in simulation_timestep

def logic_level_of_pad(state, coords):
	"""Return True if the pad voltage at coords is above half of Vcc."""
	threshold = Vcc / 2
	pad_voltage = state.voltage_of_pad(coords)
	return pad_voltage > threshold

def apply_logic_level_to_pad(state, level, coords):
	"""Drive the pad at coords with the voltage level * Vcc."""
	pad_voltage = level * Vcc
	state.apply_voltage_to_pad(pad_voltage, coords)

def dump(name, a):
	if dumping:
		lbl = name
		for i in xrange(min(dump_rows, a.shape[0])):
			print "%-7s" % lbl,
			lbl = ""
			for j in xrange(min(6, a.shape[1])):
				print "%8.3f" % a[i, j],
			print

def simulation_timestep(chip, num_steps):
	"""Advance the simulation of chip by num_steps timesteps.

	Works in place on the arrays of chip.state (Q, V, Ix, Iy, Iz and,
	through Qp, the pad cells of Q[0]); nothing is returned.  In the
	two-plane arrays, index 0 is the metal layer and index 1 the
	silicon layer.

	chip        object providing .state (presumably a ChipState),
	            .structure.cells (the cell array) and .size
	num_steps   number of iterations to run; state.time is advanced by
	            this amount only while state.auto_mode is AUTO_RUN
	"""
	if profiling:
		#  NOTE(review): the standard-library profile module has no
		#  begin()/end(); presumably a project-local module named
		#  profile is intended -- confirm.
		profile.begin("timestep")
	#m, n = chip.size
	state = chip.state
	Q = state.Q
	V = state.V
	Ix = state.Ix
	Iy = state.Iy
	Iz = state.Iz
	Z = state.Z
	Qp = state.Qp
	Qo = state.Qo
	Zp = state.Zp
	#  Resistance of every horizontal / vertical link.  Plane 0 (metal)
	#  is constant; plane 1 (silicon) is left uninitialised here and
	#  recomputed on every step inside the loop.
	Rx = ndarray(Ix.shape)
	Ry = ndarray(Iy.shape)
	Rx[0] = Rmetal
	Ry[0] = Rmetal
	Rz = Rmetal
	cells = chip.structure.cells
	M = cells[:, :, MET]
	S = cells[:, :, SIL]
	#  |S| ** gamma, carrying the sign of S
	Ds = (abs(S).astype(int32) ** gamma) * sign(S)
	#  K[layer] is True where that layer's field of the cell is non-zero.
	#  NOTE(review): "<>" is Python 2 only syntax.
	K = cells[:, :, MET:SIL+1:3].transpose((2, 0, 1)) <> 0
	#  Connection fields of both layers, trimmed to the shapes of Ix / Iy
	Kx = cells[:, :-1, MET+HC:SIL+HC+1:3].transpose((2, 0, 1))
	Ky = cells[:-1, :, MET+VC:SIL+VC+1:3].transpose((2, 0, 1))
	Kz = cells[:, :, VIA]
	#  Per-row coupling factor between the inner edge column of a pad
	#  (column 2 / -3) and the cell next to it (column 3 / -4): the
	#  metal h-connection field of that link scaled by (1 - Z).
	Y0 = (1 - Z[0]) * Kx[0, :, 2]
	Y1 = (1 - Z[1]) * Kx[0, :, -3]
	#dump("K[0]", K[0])
	#dump("K[1]", K[1])
	#dump("Kz", Kz)
	n = num_steps
	while n:
		if dumping:
			print "t =", state.time + num_steps - n
		n -= 1
		#  Where Y is 1 the cell next to the pad takes the charge of the
		#  pad's edge column; where Y is 0 it keeps its own charge.
		Q[0, :, 3] = Y0 * Q[0, :, 2] + (1 - Y0) * Q[0, :, 3]
		Q[0, :, -4] = Y1 * Q[0, :, -3] + (1 - Y1) * Q[0, :, -4]
		dump("Q[0]", Q[0])
		#dump("Q[1]", Q[1])
		#V[:] = Q * Cinv
		#  Voltage from charge.
		#  NOTE(review): V[:, :3] and V[:, -3:] slice axis 1, i.e. the
		#  first / last three ROWS of both planes, whereas the pad views
		#  Qp cover the first / last three COLUMNS of Q[0].  Currently
		#  harmless because C == Cpad -- confirm the intended axis.
		V[:, :3] = Q[:, :3] / Cpad
		V[:, 3:-3] = Q[:, 3:-3] * Cinv
		V[:, -3:] = Q[:, -3:] / Cpad
		#dump("V[0]", V[0])
		#dump("V[1]", V[1])
		#dump("Ix[0]", Ix[0])
		#dump("Ix[1]", Ix[1])
		#dump("Iy[1]", Iy[1])
		#dump("Iz", Iz)
		#  Voltage differences across each horizontal, vertical and
		#  metal-to-silicon link
		dVx = V[:, :, :-1] - V[:, :, 1:]
		dVy = V[:, :-1, :] - V[:, 1:, :]
		dVz = V[0] - V[1]
		#  G is dVz * Fg on cells that have metal and silicon but no via
		#  (for 0/1 flag values), and zero elsewhere.
		G = dVz * (~Kz & M & (S <> 0)) * Fg
		D = Ds - G
		#dump("G", G)
		#dump("D", D)
		#  Silicon resistance falls as |D| grows; 1e-6 avoids dividing
		#  by zero.  A link's resistance is the mean of its two cells.
		Rs = Rsilicon / (abs(D) + 1e-6)
		Rx[1] = 0.5 * (Rs[:, 1:] + Rs[:, :-1])
		Ry[1] = 0.5 * (Rs[1:, :] + Rs[:-1, :])
		#dump("Rx[0]", Rx[0])
		#dump("Rx[1]", Rx[1])
		#dump("Ry[0]", Ry[0])
		#dump("Ry[1]", Ry[1])
		#dump("Rx[0]", Rx[0])
		Dx0 = D[:, :-1]
		Dx1 = D[:, 1:]
		Dy0 = D[:-1, :]
		Dy1 = D[1:, :]
		#  J is +1 where D changes from positive to negative across the
		#  link, -1 where it changes from negative to positive, else 0.
		Jx = ((Dx0 > 0) & (Dx1 < 0)) + (-1 * ((Dx0 < 0) & (Dx1 > 0)))
		Jy = ((Dy0 > 0) & (Dy1 < 0)) + (-1 * ((Dy0 < 0) & (Dy1 > 0)))
		#dump("S", S)
		#dump("Dx0", Dx0)
		#dump("Dx1", Dx1)
		#dump("Jx", Jx)
		#dump("Jy", Jy)
		#Bx = Jx * dVx[1] >= 0.0
		#By = Jy * dVy[1] >= 0.0
		#dump("By", By)
		#  Current update for every link: dI = (dV - R * I) / (L + R)
		dIx = (dVx - Rx * Ix) / (L + Rx)
		dIy = (dVy - Ry * Iy) / (L + Ry)
		dIz = (dVz - Rz * Iz) / (L + Rz)
		#dump("dIx[0]", dIx[0])
		#dump("dIx[1]", dIx[1])
		#dump("dIy[1]", dIy[1])
		#dump("dIz", dIz)
		Ix += dIx
		Iy += dIy
		Iz += dIz
		#dump("Ix[0]#2", Ix[0])
		#dump("Iz#2", Iz)
		#  B is True where the silicon current has the same sign as J
		#  (always True where J is 0).
		Bx = Jx * Ix[1] >= 0.0
		By = Jy * Iy[1] >= 0.0
		#  Zero the current on links whose connection field is 0 and, in
		#  silicon, on links where B is False.
		Ix[0] *= Kx[0]
		Ix[1] *= Kx[1] & Bx
		Iy[0] *= Ky[0]
		Iy[1] *= Ky[1] & By
		#  NOTE(review): if Ky only holds 0 / 1 this repeats the masking
		#  done by the two lines above; Ix has no equivalent line.
		Iy *= Ky
		Iz *= Kz
		#dump("Ix[0]#3", Ix[0])
		#dump("Iz#3", Iz)
		#  No current flows on the metal links between a pad's edge
		#  column and the next cell; that coupling is done with Y0 / Y1
		#  at the top of the loop.
		Ix[0, :, 2] = 0.0  # *= (1 - Z[0])
		Ix[0, :, -3] = 0.0  # *= (1 - Z[1])
		#  Move charge along each link: out of the first cell of the
		#  pair and into the second.
		Q[:, :, :-1] -= Ix
		Q[:, :, 1:] += Ix
		Q[:, :-1, :] -= Iy
		Q[:, 1:, :] += Iy
		Q[0] -= Iz
		Q[1] += Iz
		#dump("Q[0]#2", Q[0])
		#dump("Q[1]#2", Q[1])
		#  Clear the charge of cells where the layer is absent
		Q *= K
		#dump("Q[0]#3", Q[0])
		#dump("Q[1]#3", Q[1])
	#  After the last step, every pad whose Zp entry is non-zero has its
	#  whole 3x3 block set to the largest charge among its three
	#  neighbouring cells (Qo), each weighted by the connection field
	#  of its link (Kp).
	npads = chip.size[0] // 3
	Kp = [Kx[0, :, i].reshape((npads, 3)) for i in (2, -3)]
	for i in (0, 1):
		for j in xrange(npads):
			if Zp[i][j]:
				Qp[i][j] = amax(Qo[i][j] * Kp[i][j])
	dump("Qp", Qp[0][0])
	dump("Q[0]*", Q[0])
	#  The elapsed-time counter only advances in AUTO_RUN mode
	if state.auto_mode == AUTO_RUN:
		state.time += num_steps
	if profiling:
		global stepcount
		stepcount += num_steps
		profile.end("timestep")
		print "Step", stepcount
