#--------------------------------------------------------------
#
#   SimChip2 - Rendering buffer
#
#--------------------------------------------------------------

from weakref import WeakKeyDictionary
from numpy import ndarray, uint8, asarray, array, newaxis, interp
from GUI.Numerical import image_from_ndarray
from resources import get_pil_image
from simulation import _MET_, _SIL_, MET, SIL, VIA, MS, HC, VC, Vcc
import profile

# Element type of the RGBA pixel buffers: four uint8 channels per pixel
# (channels 0-2 are colour, channel 3 is alpha).
rgba_dtype = uint8, 4
# Element type for single-channel mask data.
mask_dtype = uint8

def get_sym(name, mode = None):
	"""Load the symbol image "sym-<name>.png" and return it as a numpy array.

	When mode is given (e.g. 'L'), the PIL image is converted to that mode
	before the conversion to an array.
	"""
	filename = "sym-" + name + ".png"
	image = get_pil_image(filename)
	return asarray(image.convert(mode) if mode else image)

def get_symarray(name, **kwds):
	"""Load a symbol sheet and view it as a 4 x 5 grid of square tiles.

	The tile edge is a quarter of the image height, so the returned array
	is indexed [tile_row, y, tile_col, x, ...]; any trailing (channel) axes
	of the image are kept. The image width is expected to be exactly five
	tile edges, otherwise the reshape fails. Extra keyword arguments are
	passed on to get_sym().
	"""
	sheet = get_sym(name, **kwds)
	tile = sheet.shape[0] // 4
	grid_shape = (4, tile, 5, tile) + sheet.shape[2:]
	return sheet.reshape(grid_shape)

# Symbol images, loaded once at import time.
# Via symbol: assigned whole to one cell of the via buffer.
sym_via = get_sym("via")
# Tile grids (see get_symarray) for metal and for each silicon type;
# sym_silicon is ordered p, p+, n+, n.
sym_metal = get_symarray("metal")
sym_silicon = list(map(get_symarray, ("p", "p+", "n+", "n")))
# Greyscale ('L') mask tile grids for the metal and silicon layers.
sym_mask_metal = get_symarray("mask-metal", mode='L')
sym_mask_silicon = get_symarray("mask-silicon", mode='L')

#--------------------------------------------------------------

# Render buffers keyed weakly by chip, so an entry disappears once its
# chip is no longer referenced elsewhere.
cache = WeakKeyDictionary()

def get_render_buffer(chip):
	"""Return chip's render buffer, brought up to date with buf.update(chip).

	A new ChipRenderBuffer is created on first use, and again whenever the
	cached buffer's size no longer matches chip.size (its arrays would be
	the wrong shape for the chip).
	"""
	buf = cache.get(chip)
	# Test against None explicitly: a cached buffer must never be treated
	# as missing just because of its truth value.
	if buf is None or tuple(buf.size) != tuple(chip.size):
		buf = ChipRenderBuffer(chip.size)
		cache[chip] = buf
	buf.update(chip)
	return buf

def invalidate_render_buffer(chip, coords = None):
	"""Mark chip's cached render buffer as needing a redraw.

	With coords (a (row, col) pair) the buffer's invalidate_cell(coords) is
	called; otherwise the whole buffer is invalidated. Does nothing when no
	buffer has been cached for chip yet.
	"""
	buf = cache.get(chip)
	if buf is not None:
		if coords is not None:
			buf.invalidate_cell(coords)
		else:
			buf.invalidate()

#--------------------------------------------------------------

# Colour ramp for voltages: black at -5, blue at 0, red at 5, yellow at 10.
# numpy.interp clamps values outside [-5, 10] to the end colours.
vrgb_x = array([-5.0, 0.0, 5.0, 10.0])
vrgb_r = array([0.0, 0.0, 255.0, 255.0])
vrgb_g = array([0.0, 0.0, 0.0, 255.0])
vrgb_b = array([0.0, 255.0, 0.0, 0.0])

def voltage_to_rgb(v, rgb):
	"""Paint the ramp colour of each voltage in v into rgb, in place.

	rgb is indexed rgb[..., channel, y, x] with channel 0/1/2 = R/G/B; its
	leading axes are broadcast against v, and each voltage's colour is
	written to every (y, x) pixel under it. Channel values are linearly
	interpolated along the ramp and truncated to uint8.
	"""
	for channel, ramp in enumerate((vrgb_r, vrgb_g, vrgb_b)):
		level = interp(v, vrgb_x, ramp).astype(uint8)
		rgb[..., channel, :, :] = level[..., newaxis, newaxis]

#--------------------------------------------------------------

class ChipRenderBuffer(object):
	"""Pixel buffers used to draw a chip, one numpy array per layer.

	Every buffer holds one RGBA pixel (rgba_dtype) per element:
	  via                 (nrows, cellsize, ncols, cellsize)
	  metal, silicon      (nrows, 2, cellsize//2, ncols, 2, cellsize//2) --
	                      each cell split into four quarters so every quarter
	                      can be drawn from its own symbol tile.
	  v_metal, v_silicon  (nrows, cellsize, ncols, cellsize) -- RGB filled by
	                      voltage_to_rgb(); the alpha channel is shared with
	                      mask_metal / mask_silicon.

	The arrays are allocated with ndarray() and so start uninitialised;
	invalid starts True so the first update() draws every cell.
	"""

	def __init__(self, size, cellsize = 16):
		# size is (nrows, ncols) in cells; cellsize is a cell's edge length in
		# pixels (cells are drawn in halves, so it is expected to be even).
		self.size = size
		self.cellsize = cellsize
		nrows, ncols = size
		halfsize = cellsize // 2
		shape = (nrows, cellsize, ncols, cellsize)
		qshape = (nrows, 2, halfsize, ncols, 2, halfsize)
		bycell = (0, 2, 1, 3)
		rgb_bycell = (0, 2, 4, 1, 3)
		self.via = ndarray(shape, rgba_dtype)
		self.metal = ndarray(qshape, rgba_dtype)
		self.silicon = ndarray(qshape, rgba_dtype)
		self.v_metal = ndarray(shape, rgba_dtype)
		self.v_silicon = ndarray(shape, rgba_dtype)
		# (nrows, ncols, 3, cellsize, cellsize) views of the RGB channels, so a
		# per-cell colour can be broadcast over all of that cell's pixels.
		self.v_metal_rgb_bycell = self.v_metal[..., :3].transpose(rgb_bycell)
		self.v_silicon_rgb_bycell = self.v_silicon[..., :3].transpose(rgb_bycell)
		# The layer masks are views of the voltage buffers' alpha channel:
		# writing a mask in update_range() sets where the voltage colour shows.
		self.mask_metal = self.v_metal[..., 3].reshape(qshape)
		self.mask_silicon = self.v_silicon[..., 3].reshape(qshape)
		self.mask_metal_bycell = self.mask_metal.reshape(shape).transpose(bycell)
		self.mask_silicon_bycell = self.mask_silicon.reshape(shape).transpose(bycell)
		self.invalid_cells = set()
		self.invalid = True

	def invalidate(self):
		"""Request a redraw of every cell on the next update()."""
		self.invalid = True

	def invalidate_cell(self, coords):
		"""Request a redraw of the cell at coords = (row, col), and of its
		neighbours, on the next update()."""
		self.invalid_cells.add(coords)

	@staticmethod
	def _affected_cells(cells, size):
		"""Return the in-bounds cells within one step (diagonals included) of
		any (row, col) in cells, for a grid of size (nrows, ncols).

		update_range() draws a cell from the connection flags of every cell
		in its 3x3 neighbourhood, so a changed cell invalidates all of those
		neighbours' drawings too.
		"""
		nrows, ncols = size
		affected = set()
		for row, col in cells:
			for r in (row - 1, row, row + 1):
				if 0 <= r < nrows:
					for c in (col - 1, col, col + 1):
						if 0 <= c < ncols:
							affected.add((r, c))
		return affected

	def update(self, chip):
		"""Redraw whatever is out of date, refresh the voltage colours and
		rebuild the images."""
		if self.invalid:
			nrows, ncols = self.size
			self.update_range(chip, xrange(nrows), xrange(ncols))
		else:
			for row, col in self._affected_cells(self.invalid_cells, self.size):
				self.update_range(chip, [row], [col])
		self.update_voltages(chip)
		self.update_images()
		self.invalid_cells.clear()
		self.invalid = False

	def update_range(self, chip, row_range, col_range):
		"""Redraw the silicon, metal and via layers of every cell in
		row_range x col_range from chip.structure.cells.

		Cells where a layer is absent have that layer (and its mask) cleared
		to zero, so nothing stale survives a redraw.
		"""
		cells = chip.structure.cells
		chip_via = cells[..., VIA]
		chip_metal = cells[..., _MET_]
		chip_silicon = cells[..., _SIL_]
		buf_via = self.via
		buf_metal = self.metal
		buf_silicon = self.silicon
		buf_mask_metal = self.mask_metal
		buf_mask_silicon = self.mask_silicon
		nrows, ncols = chip.size
		maxrow = nrows - 1
		maxcol = ncols - 1

		def draw_cell(row, col, layer, buf, buf_mask, sym, sym_mask):
			# Draw one cell of a layer as four quarter tiles picked from the
			# symbol sheet according to the cell's connections.
			# A cell's HC flag links it to the cell on its right, its VC flag
			# to the cell below.
			left = col > 0 and layer[row, col-1, HC]
			up = row > 0 and layer[row-1, col, VC]
			right = col < maxcol and layer[row, col, HC]
			down = row < maxrow and layer[row, col, VC]
			# A corner is "fully connected" only when the four cells sharing
			# it are linked all the way round; it then gets its own tile.
			upleft = up and left and layer[row-1, col-1, HC] and layer[row-1, col-1, VC]
			upright = up and right and layer[row-1, col, HC] and layer[row-1, col+1, VC]
			downleft = down and left and layer[row+1, col-1, HC] and layer[row, col-1, VC]
			downright = down and right and layer[row+1, col, HC] and layer[row, col+1, VC]

			def draw_quarter(qi, qj, hcon, vcon, hvcon):
				# qi/qj pick the top/bottom and left/right half of the cell.
				# Sheet row = quarter index; sheet column = connection bits
				# (1 horizontal, 2 vertical), or 4 for a fully connected corner.
				symrow = qi << 1 | qj
				if hvcon:
					symcol = 4
				else:
					# Coerce to 0/1 so a flag stored as any truthy value still
					# selects a valid sheet column.
					symcol = (1 if hcon else 0) | (2 if vcon else 0)
				buf[row, qi, :, col, qj, :] = sym[symrow, :, symcol]
				buf_mask[row, qi, :, col, qj, :] = sym_mask[symrow, :, symcol]

			draw_quarter(0, 0, left, up, upleft)
			draw_quarter(0, 1, right, up, upright)
			draw_quarter(1, 0, left, down, downleft)
			draw_quarter(1, 1, right, down, downright)

		def clear_cell(row, col, buf, mask):
			buf[row, :, :, col, :, :] = 0
			mask[row, :, :, col, :, :] = 0

		for row in row_range:
			for col in col_range:
				s = chip_silicon[row, col, MS]
				if s:
					# Selects one of the four silicon sheets (p, p+, n+, n).
					# NOTE(review): assumes s % 5 is in 1..4 -- confirm MS encoding.
					i = (s % 5) - 1
					draw_cell(row, col, chip_silicon, buf_silicon, buf_mask_silicon,
						sym_silicon[i], sym_mask_silicon)
				else:
					clear_cell(row, col, buf_silicon, buf_mask_silicon)
				if chip_metal[row, col, MS]:
					draw_cell(row, col, chip_metal, buf_metal, buf_mask_metal,
						sym_metal, sym_mask_metal)
				else:
					clear_cell(row, col, buf_metal, buf_mask_metal)
				if chip_via[row, col]:
					buf_via[row, :, col, :] = sym_via
				else:
					buf_via[row, :, col, :] = 0

	def update_voltages(self, chip):
		"""Colour the RGB channels of v_metal / v_silicon from chip.state.Vm /
		chip.state.Vs via voltage_to_rgb()."""
		state = chip.state
		voltage_to_rgb(state.Vm, self.v_metal_rgb_bycell)
		voltage_to_rgb(state.Vs, self.v_silicon_rgb_bycell)

	def update_images(self):
		"""Rebuild silicon_image, metal_image, via_image, v_silicon_image and
		v_metal_image from the pixel buffers with image_from_ndarray()."""
		cs = self.cellsize
		nrows, ncols = self.size
		size = (ncols * cs, nrows * cs)  # (width, height) in pixels
		self.silicon_image = image_from_ndarray(self.silicon, "RGBA", size)
		self.metal_image = image_from_ndarray(self.metal, "RGBA", size)
		self.via_image = image_from_ndarray(self.via, "RGBA", size)
		self.v_silicon_image = image_from_ndarray(self.v_silicon, "RGBA", size)
		self.v_metal_image = image_from_ndarray(self.v_metal, "RGBA", size)
