#--------------------------------------------------------------
#
#   SimChip2 - Chip simulation
#
#--------------------------------------------------------------

from __future__ import division
from base64 import b64encode, b64decode
import numpy
from numpy import ndarray, dtype, int8, uint8, float32, zeros, \
	amax, sign, int32, array, bitwise_or, flatnonzero, rot90
from GUI.Geometry import empty_rect, sect_rect, offset_rect
import profile

#--------------------------------------------------------------

#  Each cell is a vector of 7 int8 layer values (layout below), so a
#  cell array of shape (nrows, ncols) is stored as (nrows, ncols, 7).
cell_dtype = (int8, 7)
#  NOTE(review): intended element type for voltage arrays; not referenced
#  in this part of the file - confirm it is still used.
v_dtype = float32

#  Supply voltage. Logic high is Vcc; the logic threshold is Vcc / 2.
Vcc = 5.0

#  Layer index -> contents of each cell. An h-connection links a cell to
#  its right-hand neighbour, a v-connection links it to the cell below.
#  0: metal (0 or 1)
#  1: metal h-connection (0 or 1)
#  2: metal v-connection (0 or 1)
#  3: silicon (-2 to +2)
#  4: silicon h-connection (0 or 1)
#  5: silicon v-connection (0 or 1)
#  6: via (0 or 1)

MET = 0   #  Metal layer
SIL = 3   #  Silicon layer
VIA = 6   #  Via layer

_MET_ = slice(0, 3)    #  Metal and its connections (layers 0-2)
_SIL_ = slice(3, 6)    #  Silicon and its connections (layers 3-5)

#  Offsets from a base layer (MET or SIL) to its related layers
MS = 0   #  Base layer (metal or silicon)
HC = 1   #  Offset to horizontal connections layer
VC = 2   #  Offset to vertical connections layer

_MS_ = slice(0, 4, 3)   #  Metal and silicon (layers 0, 3)
_MSV_ = slice(0, 7, 3)  #  Metal, silicon and vias (layers 0, 3, 6)
_HC_ = slice(1, 5, 3)   #  Horizontal connections (layers 1, 4)
_VC_ = slice(2, 6, 3)   #  Vertical connections (layers 2, 5)

#  ChipState.auto_mode values; state.time only advances in AUTO_RUN.
AUTO_OFF = 0
AUTO_PAUSE = 1
AUTO_RUN = 2

#--------------------------------------------------------------

class CellArray(object):
	"""Base class for a rectangular grid of chip cells.

	``cells`` has shape (nrows, ncols, 7) and dtype int8; the last axis
	holds one value per layer (see the MET / SIL / VIA layout constants).
	Subclasses must provide ``magic`` (the serialisation tag written by
	to_string) and ``get_bounds()`` (used by the ``bounds`` property).
	"""
	#  cells   ndarray[nrows, ncols] of cell_dtype

	#  Integer type each cell is packed into by to_string().
	packed_dtype = 'h'

	size = property(lambda self: self.cells.shape[:2])
	num_rows = property(lambda self: self.cells.shape[0])
	num_cols = property(lambda self: self.cells.shape[1])

	bounds = property(lambda self: self.get_bounds())

	def __init__(self, nrows, ncols):
		self.cells = zeros((nrows, ncols), cell_dtype)

	def get_cell(self, coords):
		"""Return a copy of the layer values of the cell at coords (presumably (row, col))."""
		return array(self.cells[coords])

	def get_rect(self, rect):
		"""Return a copy of the cells in rect = (top, left, bottom, right), bottom/right exclusive."""
		t, l, b, r = rect
		return array(self.cells[t:b, l:r])

	def set_rect(self, rect, cells):
		"""Overwrite the cells in rect = (top, left, bottom, right) with ``cells``."""
		t, l, b, r = rect
		self.cells[t:b, l:r] = cells

	def to_string(self):
		"""Serialise as "<magic>,<packed_dtype>,<nrows-1>,<ncols-1>,<base64>".

		Each cell is packed into one ``packed_dtype`` integer:
		bit 0 metal, bit 1 metal h-conn, bit 2 metal v-conn,
		bits 3-5 silicon + 2, bit 6 silicon h-conn, bit 7 silicon v-conn,
		bit 8 via.

		The stored row/column counts are one less than the array size,
		which matches a Snippet (whose array includes a one-cell border).
		NOTE(review): for a border-less subclass the counts are therefore
		one short of the real size - confirm the reader expects this.
		"""
		cells = self.cells.astype(self.packed_dtype)
		nr, nc = self.size
		packed = (
			   cells[..., MET     ]
			| (cells[..., MET + HC] << 1)
			| (cells[..., MET + VC] << 2)
			|((cells[..., SIL] + 2) << 3)
			| (cells[..., SIL + HC] << 6)
			| (cells[..., SIL + VC] << 7)
			| (cells[..., VIA     ] << 8)
		)
		#  ndarray.tostring was removed in NumPy 2.0; tobytes is the same
		#  operation under its newer name. Fall back for very old NumPy.
		to_bytes = getattr(packed, "tobytes", None) or packed.tostring
		data = b64encode(to_bytes())
		return "%s,%s,%s,%s,%s" % (self.magic, self.packed_dtype, nr-1, nc-1, data)
	
#--------------------------------------------------------------

class Snippet(CellArray):
	#  Snippets have an extra row at the top and an extra column
	#  on the left side containing connections (the "border").
	#  The border is included in the size but not the bounds.

	magic = "SChS"

	position = (0, 0)    #  Position of top left corner excluding border
	render_buffer = None

	def __init__(self, nrows, ncols):
		CellArray.__init__(self, nrows + 1, ncols + 1)

	def __getstate__(self):
		#  Pickle without the render buffer (invalidate() resets it too).
		d = self.__dict__.copy()
		d['render_buffer'] = None
		return d

	def get_bounds(self):
		"""Return (top, left, bottom, right) on the chip, excluding the border.

		bottom/right are exclusive.
		"""
		t, l = self.position
		nr, nc = self.size
		return (t, l, t + nr - 1, l + nc - 1)

	@classmethod
	def is_valid_string(cls, data):
		"""Return True if data is a string with a well-formed snippet header.

		Only the header is checked; the base64 payload is not decoded.
		"""
		if not isinstance(data, basestring):
			return False
		try:
			cls.unpack_header(data)
			return True
		except ValueError:
			return False

	@classmethod
	def unpack_header(cls, string):
		"""Split a serialised snippet into (dtype, nrows, ncols, base64_data).

		nrows/ncols exclude the border. Raises ValueError if the string
		does not have exactly five comma-separated fields, carries the
		wrong magic, or has non-integer sizes.
		"""
		magic, code, nrows, ncols, data = string.split(",")
		if magic != cls.magic:
			raise ValueError("Wrong magic for snippet")
		return code, int(nrows), int(ncols), data

	@classmethod
	def from_string(cls, string):
		"""Rebuild a snippet from the output of CellArray.to_string().

		The bit layout must mirror to_string:
		bit 0 metal, bit 1 metal h-conn, bit 2 metal v-conn,
		bits 3-5 silicon + 2, bit 6 silicon h-conn, bit 7 silicon v-conn,
		bit 8 via.
		Raises ValueError for a malformed header or a payload whose size
		does not match the header.
		"""
		code, nrows, ncols, b64 = cls.unpack_header(string)
		data = b64decode(b64)
		#  frombuffer replaces the deprecated binary mode of fromstring; the
		#  result is only read below, so a read-only view is fine.
		packed = numpy.frombuffer(data, code).reshape((nrows + 1, ncols + 1))
		self = cls(nrows, ncols)
		cells = self.cells
		cells[..., MET     ] =   packed       & 1
		cells[..., MET + HC] =  (packed >> 1) & 1
		cells[..., MET + VC] =  (packed >> 2) & 1
		cells[..., SIL     ] = ((packed >> 3) & 7) - 2
		cells[..., SIL + HC] =  (packed >> 6) & 1
		cells[..., SIL + VC] =  (packed >> 7) & 1
		cells[..., VIA     ] =  (packed >> 8) & 1
		return self

	def center_in_chip(self, chip):
		"""Position the snippet in the middle of chip."""
		nr, nc = self.size
		cnr, cnc = chip.size
		self.position = (cnr - nr) // 2, (cnc - nc) // 2

	def center_on_coords(self, coords):
		"""Position the snippet so that it is centred on coords = (row, col)."""
		row, col = coords
		nr, nc = self.size
		self.position = row - (nr - 1) // 2, col - (nc - 1) // 2

	def center_in_rect(self, rect):
		"""Position the snippet in the middle of rect = (top, left, bottom, right)."""
		t, l, b, r = rect
		nr, nc = self.size
		self.position = (t + b - (nr - 1)) // 2, (l + r - (nc - 1)) // 2

	def clone(self):
		"""Return an empty snippet of the same size and position."""
		nrows, ncols = self.size
		snip = Snippet(nrows - 1, ncols - 1)
		snip.position = self.position
		return snip

	def rotated_clone(self):
		"""Return an empty snippet with rows/columns swapped, about the same centre."""
		nrows, ncols = self.size
		y, x = self.position
		cx = x + ncols // 2
		cy = y + nrows // 2
		snip = Snippet(ncols - 1, nrows - 1)
		snip.position = (cy - ncols // 2, cx - nrows // 2)
		return snip

	def flip_horizontal(self):
		"""Return a left-right mirrored copy of this snippet."""
		snip = self.clone()
		#  Cell contents skip the border column; h-connections include it,
		#  because a connection stored at column c links c and c+1.
		snip.cells[:, 1:, _MSV_] = self.cells[:, -1:0:-1, _MSV_]
		snip.cells[:, 1:, _VC_] = self.cells[:, -1:0:-1, _VC_]
		snip.cells[:, :, _HC_] = self.cells[:, ::-1, _HC_]
		return snip

	def flip_vertical(self):
		"""Return a top-bottom mirrored copy of this snippet."""
		snip = self.clone()
		snip.cells[1:, :, _MSV_] = self.cells[-1:0:-1, :, _MSV_]
		snip.cells[1:, :, _HC_] = self.cells[-1:0:-1, :, _HC_]
		snip.cells[:, :, _VC_] = self.cells[::-1, :, _VC_]
		return snip

	def rotate(self, k):
		"""Return a copy rotated by k quarter turns, counterclockwise (numpy.rot90 convention).

		On an odd number of quarter turns the horizontal and vertical
		connection layers swap roles, which the direct mapping below
		handles. Even turns are built from two single quarter turns, or a
		plain copy when k % 4 == 0.
		"""
		k %= 4
		if k == 0:
			snip = self.clone()
			snip.cells[...] = self.cells
			return snip
		if k == 2:
			return self.rotate(1).rotate(1)
		snip = self.rotated_clone()
		snip.cells[1:, 1:, _MSV_] = rot90(self.cells[1:, 1:, _MSV_], k)
		snip.cells[1:, :, _HC_] = rot90(self.cells[:, 1:, _VC_], k)
		snip.cells[:, 1:, _VC_] = rot90(self.cells[1:, :, _HC_], k)
		return snip

	def invalidate(self):
		"""Discard the cached render buffer."""
		self.render_buffer = None
	
#--------------------------------------------------------------

class ChipStructure(CellArray):
	"""The complete cell grid of a chip.

	Unlike Snippet there is no border row/column: bounds are simply
	(0, 0, nrows, ncols). paste() never touches the three leftmost and
	three rightmost columns (the pad columns).
	"""

	magic = "SChC"
	pad_height = 3

	def __init__(self, nrows, ncols, pad_height = 3):
		CellArray.__init__(self, nrows, ncols)
		self.pad_height = pad_height

	def get_bounds(self):
		nr, nc = self.size
		return (0, 0, nr, nc)

	def copy(self, r):
		"""Return a Snippet holding the cells inside r = (top, left, bottom, right).

		bottom/right are exclusive and are clipped to the chip's far edges.
		top/left are assumed to be non-negative. When there is a row above
		or a column to the left of the rect, the snippet's border row/column
		receives the vertical/horizontal connections that lead into the
		copied area.
		"""
		t, l, b, right = r
		#  Clip the far edges rather than the extent, so a rect that starts
		#  inside the chip but runs past its end still matches the source
		#  slice taken below.
		nr = max(0, min(b, self.num_rows) - t)
		nc = max(0, min(right, self.num_cols) - l)
		snip = Snippet(nr, nc)
		snip.cells[1:nr+1, 1:nc+1] = self.cells[t:t+nr, l:l+nc]
		if l > 0:
			snip.cells[1:nr+1, 0, _HC_] = self.cells[t:t+nr, l-1, _HC_]
		if t > 0:
			snip.cells[0, 1:nc+1, _VC_] = self.cells[t-1, l:l+nc, _VC_]
		snip.position = (t, l)
		return snip

	def paste(self, snip):
		"""Overlay snip onto the chip at snip.position.

		The area is clipped to the non-pad columns. Layer by layer, every
		nonzero snippet value replaces the chip value, and zeros leave the
		chip untouched. Connections along the edges of the pasted area
		that no longer join two occupied cells are then cleared.
		Returns the clipped rect (top, left, bottom, right), which may be
		empty.
		"""
		chip_rows, chip_cols = self.size
		chip_rect = (0, 3, chip_rows, chip_cols - 3)
		paste_rect = sect_rect(snip.bounds, chip_rect)
		if empty_rect(paste_rect):
			return paste_rect
		t, l, b, r = paste_rect
		pt, pl = snip.position
		#  Snippet array indices are offset by its border row/column.
		snip_rect = offset_rect(paste_rect, (1 - pt, 1 - pl))
		st, sl, sb, sr = snip_rect
		#  Include one extra row above / column to the left (when there is
		#  one) so the snippet's border connections land on the chip.
		if t > 0:
			t -= 1
			st -= 1
		if l > 0:
			l -= 1
			sl -= 1
		cells = self.cells
		old_cells = cells[t:b, l:r]
		new_cells = snip.cells[st:sb, sl:sr]
		old_cells[...] = numpy.where(new_cells != 0, new_cells, old_cells)
		def fix_hc(x):
			#  Keep an h-connection at column x only if both cells it joins
			#  are occupied on that layer; none can leave the last column.
			if x < chip_cols - 1:
				keep = (cells[t:b, x, _MS_] != 0) & (cells[t:b, x+1, _MS_] != 0)
			else:
				keep = 0
			cells[t:b, x, _HC_] &= keep
		def fix_vc(y):
			#  Same for v-connections at row y.
			if y < chip_rows - 1:
				keep = (cells[y, l:r, _MS_] != 0) & (cells[y+1, l:r, _MS_] != 0)
			else:
				keep = 0
			cells[y, l:r, _VC_] &= keep
		fix_hc(l)
		fix_hc(r-1)
		fix_vc(t)
		fix_vc(b-1)
		return paste_rect

	def clear(self, r):
		"""Zero every layer of the cells in r = (top, left, bottom, right).

		Also clears the connections from the row above and the column to
		the left that lead into the rect.
		"""
		t, l, b, r = r
		self.cells[t:b, l:r] = 0
		if t > 0:
			self.cells[t-1, l:r, _VC_] = 0
		if l > 0:
			self.cells[t:b, l-1, _HC_] = 0

	def measure_silicon_area(self):
		"""Return the area of the bounding box of all silicon cells (0 if none)."""
		def extent(a):
			nz = flatnonzero(a)
			if len(nz):
				return max(nz) - min(nz) + 1
			else:
				return 0
		sil = self.cells[..., SIL]
		#  OR-reduce along each axis; the OR of nonzero values is nonzero,
		#  so these mark the columns/rows that contain any silicon.
		vproj = bitwise_or.reduce(sil, 0)
		hproj = bitwise_or.reduce(sil, 1)
		return extent(hproj) * extent(vproj)

#--------------------------------------------------------------

class ChipState(object):
	"""Time-varying electrical state of a chip during simulation.

	Arrays with a leading axis of length 2 are indexed by layer:
	0 = metal, 1 = silicon (Vm is V[0], Vs is V[1]).
	"""

	time = 0   #  Number of timesteps elapsed
	auto_mode = AUTO_OFF
	auto_end_time = 0

	def __init__(self, chip):
		"""Allocate zeroed state arrays sized for chip.

		chip must provide ``size`` (nrows, ncols) and ``pad_height``.
		nrows must be a multiple of pad_height, or the per-pad reshapes
		below raise ValueError.
		"""
		nrows, ncols = chip.size
		hp = chip.pad_height
		self.Q = zeros((2, nrows, ncols))       #  Charge per layer and cell
		self.V = zeros((2, nrows, ncols))       #  Voltage per layer and cell
		self.Ix = zeros((2, nrows, ncols - 1))  #  Current to the right-hand neighbour
		self.Iy = zeros((2, nrows - 1, ncols))  #  Current to the neighbour below
		self.Iz = zeros((nrows, ncols))         #  Via current, metal -> silicon
		self.Vm = self.V[0]
		self.Vs = self.V[1]
		self.Z = zeros((2, nrows))              #  Pad impedance per side and row
		#  Views of the metal charge in the three pad columns on the
		#  left/right, grouped per pad: (npads, hp, 3). The simulation
		#  writes through these views into Q.
		self.Qp = [
			self.Q[0, :, :3].reshape((nrows // hp, hp, 3)),
			self.Q[0, :, -3:].reshape((nrows // hp, hp, 3))
		]
		#  Views of the first core column next to each side's pads.
		self.Qo = [
			self.Q[0, :, 3].reshape((nrows // hp, hp)),
			self.Q[0, :, -4].reshape((nrows // hp, hp))
		]
		#  Impedance of the first row of each pad (view of Z).
		self.Zp = self.Z[:, ::hp]
		#  Running total of charge moved from pads into the core, per side and row.
		self.W = zeros((2, nrows))

	def put_debug_voltage_pattern(self):
		"""Fill Vm and Vs with diagonal voltage ramps, for testing rendering."""
		Vm = self.Vm
		Vs = self.Vs
		nrows, ncols = Vm.shape
		rows = numpy.arange(nrows).reshape((nrows, 1))
		cols = numpy.arange(ncols).reshape((1, ncols))
		Vm[...] = -5.0 + 15.0 * (rows + cols) / (nrows + ncols)
		Vs[...] = -5.0 + 15.0 * (rows + (ncols - cols)) / (nrows + ncols)

	def set_pad_impedance(self, coords, z):
		"""Set impedance z on the three rows of Z centred on the pad at coords.

		Columns below 3 select the left-side pads (Z[0]); any other column
		selects the right-side pads (Z[1]).
		"""
		row, col = coords
		#  Must be a plain int: NumPy treats a bare bool index as a mask,
		#  not as 0/1.
		side = 1 if col >= 3 else 0
		self.Z[side, row-1:row+2] = z

	def voltage_of_pad(self, coords):
		"""Return the metal voltage at coords, using the pad capacitance Cpad."""
		return self.Q[0][coords] / Cpad

	def apply_voltage_to_pad(self, v, coords):
		"""Charge the 3x3 block of metal cells centred on coords to voltage v."""
		q = v * Cpad
		row, col = coords
		self.Q[0, row-1:row+2, col-1:col+2] = q
	
#--------------------------------------------------------------

#  Doping exponent: silicon doping S contributes sign(S) * |S| ** gamma
#  to the effective doping D used by simulation_timestep.
gamma = 7
#  Resistance of a metal link; also used for the via (metal-silicon) link.
Rmetal = 0.05
#  Base silicon resistance. The per-cell value is Rsilicon / (|D| + 1e-6),
#  so doping 2 with no gate effect (D = 2 ** gamma) conducts like metal.
Rsilicon = Rmetal * (2 ** gamma)
#  Inductance-like damping of current change: dI = (dV - R*I) / (L + R).
L = 1.25 #2.5
#  Cell capacitance (V = Q * Cinv) for the non-pad cells.
C = 2.0 #1.0
#  Capacitance used for pads (see ChipState.voltage_of_pad / apply_voltage_to_pad).
Cpad = 2.0
# Must have L * C >= 2.5 for stability
#  Gate factor: scales how much the metal-silicon voltage difference over
#  doped silicon (without a via) is subtracted from its doping.
Fg = 5.0

#print "Rmetal =", Rmetal
#print "Rsil1 =", Rsilicon
#print "Rsil2 =", Rsilicon / (2 ** gamma)

#  NOTE(review): Linv is not referenced in the visible code - confirm it is still needed.
Linv = 1.0 / L
Cinv = 1.0 / C

stepcount = 0     #  Total steps simulated, counted only while profiling
dumping = 0       #  Nonzero: dump() and per-step trace prints are active
dump_rows = 5     #  Maximum number of rows dump() prints
profiling = 0     #  Nonzero: time each simulation_timestep call

def logic_level_of_pad(state, coords):
	"""Return whether the pad at coords reads as logic high (above Vcc / 2)."""
	threshold = Vcc / 2
	voltage = state.voltage_of_pad(coords)
	return voltage > threshold

def apply_logic_level_to_pad(state, level, coords):
	"""Drive the pad at coords to level * Vcc (0 V for false, Vcc for true)."""
	voltage = level * Vcc
	state.apply_voltage_to_pad(voltage, coords)

def dump(name, a):
	"""Print the top-left corner of 2-D array a when dumping is enabled.

	Shows at most dump_rows rows and 6 columns, each value formatted with
	"%8.3f". The first row is labelled with name and later rows with blank
	padding. Does nothing while dumping is false.
	"""
	if not dumping:
		return
	for row in range(min(dump_rows, a.shape[0])):
		label = name if row == 0 else ""
		fields = ["%-7s" % label]
		fields.extend(["%8.3f" % a[row, col] for col in range(min(6, a.shape[1]))])
		print(" ".join(fields))

def simulation_timestep(chip, num_steps):
	if profiling:
		profile.begin("timestep")
	#m, n = chip.size
	state = chip.state
	Q = state.Q
	V = state.V
	Ix = state.Ix
	Iy = state.Iy
	Iz = state.Iz
	Z = state.Z
	Qp = state.Qp
	Qo = state.Qo
	Zp = state.Zp
	W = state.W
	Rx = ndarray(Ix.shape)
	Ry = ndarray(Iy.shape)
	Rx[0] = Rmetal
	Ry[0] = Rmetal
	Rz = Rmetal
	cells = chip.structure.cells
	M = cells[:, :, MET]
	S = cells[:, :, SIL]
	Ds = (abs(S).astype(int32) ** gamma) * sign(S)
	K = cells[:, :, MET:SIL+1:3].transpose((2, 0, 1)) <> 0
	Kx = cells[:, :-1, MET+HC:SIL+HC+1:3].transpose((2, 0, 1))
	Ky = cells[:-1, :, MET+VC:SIL+VC+1:3].transpose((2, 0, 1))
	Kz = cells[:, :, VIA]
	Y0 = (1 - Z[0]) * Kx[0, :, 2]
	Y1 = (1 - Z[1]) * Kx[0, :, -3]
	#dump("K[0]", K[0])
	#dump("K[1]", K[1])
	#dump("Kz", Kz)
	n = num_steps
	while n:
		if dumping:
			print "t =", state.time + num_steps - n
		n -= 1
		dump("Q[0]", Q[0])
		dump("Q'[0]", Q[0, :, -6:])
		W0 = Y0 * (Q[0, :, 2] - Q[0, :, 3])
		W1 = Y1 * (Q[0, :, -3] - Q[0, :, -4])
		W[0] += W0
		W[1] += W1
		Q[0, :, 3] += W0
		Q[0, :, -4] += W1
		dump("W", W.transpose())
		dump("Q2[0]", Q[0])
		dump("Q2'[0]", Q[0, :, -6:])
		#Q[0, :, 3] = Y0 * Q[0, :, 2] + (1 - Y0) * Q[0, :, 3]
		#Q[0, :, -4] = Y1 * Q[0, :, -3] + (1 - Y1) * Q[0, :, -4]
		#dump("Q[0]", Q[0])
		#dump("Q[1]", Q[1])
		#V[:] = Q * Cinv
		V[:, :3] = Q[:, :3] / Cpad
		V[:, 3:-3] = Q[:, 3:-3] * Cinv
		V[:, -3:] = Q[:, -3:] / Cpad
		#dump("V[0]", V[0])
		#dump("V[1]", V[1])
		#dump("Ix[0]", Ix[0])
		#dump("Ix[1]", Ix[1])
		#dump("Iy[1]", Iy[1])
		#dump("Iz", Iz)
		dVx = V[:, :, :-1] - V[:, :, 1:]
		dVy = V[:, :-1, :] - V[:, 1:, :]
		dVz = V[0] - V[1]
		G = dVz * (~Kz & M & (S <> 0)) * Fg
		D = Ds - G
		#dump("G", G)
		#dump("D", D)
		Rs = Rsilicon / (abs(D) + 1e-6)
		Rx[1] = 0.5 * (Rs[:, 1:] + Rs[:, :-1])
		Ry[1] = 0.5 * (Rs[1:, :] + Rs[:-1, :])
		#dump("Rx[0]", Rx[0])
		#dump("Rx[1]", Rx[1])
		#dump("Ry[0]", Ry[0])
		#dump("Ry[1]", Ry[1])
		#dump("Rx[0]", Rx[0])
		Dx0 = D[:, :-1]
		Dx1 = D[:, 1:]
		Dy0 = D[:-1, :]
		Dy1 = D[1:, :]
		Jx = ((Dx0 > 0) & (Dx1 < 0)) + (-1 * ((Dx0 < 0) & (Dx1 > 0)))
		Jy = ((Dy0 > 0) & (Dy1 < 0)) + (-1 * ((Dy0 < 0) & (Dy1 > 0)))
		#dump("S", S)
		#dump("Dx0", Dx0)
		#dump("Dx1", Dx1)
		#dump("Jx", Jx)
		#dump("Jy", Jy)
		#Bx = Jx * dVx[1] >= 0.0
		#By = Jy * dVy[1] >= 0.0
		#dump("By", By)
		dIx = (dVx - Rx * Ix) / (L + Rx)
		dIy = (dVy - Ry * Iy) / (L + Ry)
		dIz = (dVz - Rz * Iz) / (L + Rz)
		#dump("dIx[0]", dIx[0])
		#dump("dIx[1]", dIx[1])
		#dump("dIy[1]", dIy[1])
		#dump("dIz", dIz)
		Ix += dIx
		Iy += dIy
		Iz += dIz
		#dump("Ix[0]#2", Ix[0])
		#dump("Iz#2", Iz)
		Bx = Jx * Ix[1] >= 0.0
		By = Jy * Iy[1] >= 0.0
		Ix[0] *= Kx[0]
		Ix[1] *= Kx[1] & Bx
		Iy[0] *= Ky[0]
		Iy[1] *= Ky[1] & By
		Iy *= Ky
		Iz *= Kz
		#dump("Ix[0]#3", Ix[0])
		#dump("Iz#3", Iz)
		Ix[0, :, 2] = 0.0  # *= (1 - Z[0])
		Ix[0, :, -3] = 0.0  # *= (1 - Z[1])
		Q[:, :, :-1] -= Ix
		Q[:, :, 1:] += Ix
		Q[:, :-1, :] -= Iy
		Q[:, 1:, :] += Iy
		Q[0] -= Iz
		Q[1] += Iz
		#dump("Q[0]#2", Q[0])
		#dump("Q[1]#2", Q[1])
		Q *= K
		#dump("Q[0]#3", Q[0])
		#dump("Q[1]#3", Q[1])
	hp = chip.pad_height
	npads = chip.size[0] // hp
	Kp = [Kx[0, :, i].reshape((npads, hp)) for i in (2, -3)]
	for i in (0, 1):
		for j in xrange(npads):
			if Zp[i][j]:
				Qp[i][j] = amax(Qo[i][j] * Kp[i][j])
	#dump("Qp", Qp[0][0])
	#dump("Q[0]*", Q[0])
	if state.auto_mode == AUTO_RUN:
		state.time += num_steps
	if profiling:
		global stepcount
		stepcount += num_steps
		profile.end("timestep")
		print "Step", stepcount
