#--------------------------------------------------------------
#
#   SimChip2 - Chip
#
#--------------------------------------------------------------

from __future__ import division
from weakref import WeakKeyDictionary
from numpy import uint8, sum
from GUI import Model
from GUI.Geometry import sect_rect
from constants import steps_per_frame
from simulation import ChipStructure, ChipState, MET, SIL, HC, VC, VIA, \
	logic_level_of_pad, apply_logic_level_to_pad, simulation_timestep, Vcc, \
	AUTO_OFF, AUTO_PAUSE, AUTO_RUN
from rendering import invalidate_render_buffer
from waveforms import WaveformSet
from undo import get_undo_manager
import testing, www

# Device kinds attached to pins, grouped by how the simulator treats them.
input_devices = ('switch',)
output_devices = ('led',)
supply_devices = ('terminal',)
input_and_output_devices = input_devices + output_devices
input_and_supply_devices = input_devices + supply_devices

# Display names for the game speeds (presumably indexed by Chip.game_speed).
speed_names = ["Slow", "Fast"]

# Character -> value tables passed to WaveformSet.add_waveform_for_pin
# together with a waveform spec string: voltages for input stimuli and
# integer codes for output expectations ('0' -> 0 ... 'X' -> 4).
voltage_wave_values = {'0': 0.0, '1': Vcc}
expectation_values = dict((symbol, code) for code, symbol in enumerate('01LHX'))

#--------------------------------------------------------------

class ChipParams(object):
	"""Parameters used to build a Chip and its default ChipSetup.

	Instances are plain attribute bags: `speed`, `wave_order` and
	`wavespec_resolution` start from defaults that callers may override
	per instance after construction.
	"""
	#  size        (nrows, ncols)         or a false value to size from the pin count
	#  num_pins    int
	#  pin_specs   {int: (kind, name)}    Default pin setups
	#  speed       int                    Initial game speed
	#  waveforms   {int: string}          Waveform specs
	#  wave_order  [pin_index]            Pins that get default waveforms, in order
	#  vcc_pin     int
	#  wavespec_resolution   int

	speed = 1
	# Class-level default kept for backward compatibility; __init__ gives
	# every instance its own list so in-place edits cannot leak between them.
	wave_order = []
	wavespec_resolution = None

	def __init__(self, size, num_pins, pin_specs):
		"""Store the layout parameters and locate the Vcc pin.

		size       -- (nrows, ncols), or a false value to let Chip choose
		num_pins   -- total number of pins on the chip
		pin_specs  -- {pin_index: (kind, name)}; None is treated as {}

		vcc_pin is the first pin of kind 'vcc' met while iterating
		pin_specs, or the last pin (num_pins - 1) if none is marked.
		"""
		if pin_specs is None:
			pin_specs = {}
		self.size = size
		self.num_pins = num_pins
		self.pin_specs = pin_specs
		self.waveforms = {}
		self.wave_order = []
		self.vcc_pin = num_pins - 1
		for i, (kind, name) in pin_specs.items():
			if kind == 'vcc':
				self.vcc_pin = i
				break

#--------------------------------------------------------------

# Pin-spec kind -> (device, initial state) given to a new PinSetup.
# "nc" (presumably "no connection") pins get no device at all.
spec_to_device_state = dict([
	("in", ('switch', 0)),
	("out", ('led', 0)),
	("vcc", ('terminal', 1)),
	("gnd", ('terminal', 0)),
	("nc", (None, 0)),
])

#--------------------------------------------------------------

class PinSetup(object):
	"""Device, initial state and display name for one chip pin."""
	#  device   string    Key from spec_to_device_state, or None
	#  state    int       Initial device state
	#  name     string    Upper-cased display name

	def __init__(self, i, params):
		"""Configure pin `i` from params.pin_specs, or from the default layout.

		A pin with no spec (or an empty name) is named "PIN<n>", numbered
		from 1.
		"""
		spec = params.pin_specs.get(i)
		if not spec:
			kind, name = self._default_spec(i, params.num_pins)
		else:
			kind, name = spec
		self.device, self.state = spec_to_device_state[kind]
		self.name = (name or "PIN%d" % (i + 1)).upper()

	@staticmethod
	def _default_spec(i, n):
		# Default (kind, name) for pin i of n: the last pin of the first half
		# is ground, the very last pin is Vcc, the remaining first-half pins
		# are inputs and the remaining second-half pins are outputs.
		half = n // 2
		if i == half - 1:
			return "gnd", "gnd"
		if i == n - 1:
			return "vcc", "vcc"
		if i >= half:
			return "out", ""
		return "in", ""

	def is_input(self):
		"""Return True if this pin carries an input device."""
		return input_devices.__contains__(self.device)

#--------------------------------------------------------------

class ChipSetup(object):
	"""A named configuration of the chip's pins plus its waveforms.

	len(), indexing and iteration operate on the list of PinSetups.
	"""
	#  name           string
	#  pins           [PinSetup]     One per pin, indexed by pin number
	#  locked         boolean
	#  waveforms      WaveformSet    Tracks for input and output pins
	#  expectations   WaveformSet    Expectation tracks for output pins

	def __init__(self, name, params, locked = False):
		"""Build default pins and waveforms for setup `name` from a ChipParams."""
		self.name = name
		self.pins = [PinSetup(n, params) for n in xrange(params.num_pins)]
		self.locked = locked
		self.create_default_waveforms(params)

	def create_default_waveforms(self, params = None):
		"""Rebuild self.waveforms and self.expectations from scratch.

		With params, tracks are made for the pins in params.wave_order,
		using the specs in params.waveforms.  Without params, every pin is
		visited and no specs are applied.  Input-device pins get a stimulus
		track in `waveforms`; output-device pins get a spec-less track in
		`waveforms` plus an expectation track (created with uint8,
		presumably the sample dtype) in `expectations`.
		"""
		waveforms = WaveformSet()
		expectations = WaveformSet()
		pin_order = params.wave_order if params else xrange(len(self.pins))
		resolution = params and params.wavespec_resolution
		for n in pin_order:
			kind = self.pins[n].device
			wave_spec = params and params.waveforms.get(n)
			if kind in input_devices:
				waveforms.add_waveform_for_pin(n, wave_spec,
					voltage_wave_values, resolution)
			elif kind in output_devices:
				waveforms.add_waveform_for_pin(n)
				expectations.add_waveform_for_pin(n, wave_spec,
					expectation_values, resolution, uint8)
		self.waveforms = waveforms
		self.expectations = expectations

	def __len__(self):
		return len(self.pins)

	def __getitem__(self, index):
		return self.pins[index]

	def __iter__(self):
		return iter(self.pins)

	def __setstate__(self, d):
		"""Restore from a pickle, filling in attributes older pickles lack."""
		if 'locked' not in d:
			# No stored flag: treat the setup named "Test" as the locked one.
			d['locked'] = (d['name'] == "Test")
		self.__dict__.update(d)
		if 'waveforms' not in d:
			self.create_default_waveforms()

	def get_input_time_range(self):
		"""Return the latest end time among input-pin tracks (0 if none)."""
		end_times = [track.get_end_time() for track in self.waveforms.tracks
			if self.pins[track.pin_index].is_input()]
		return max([0] + end_times)

#--------------------------------------------------------------

# Chip -> ChipState.  Keys are held weakly, so an entry disappears once its
# chip is garbage-collected.
state_cache = WeakKeyDictionary()

def get_chip_state(chip):
	"""Return the ChipState for `chip`, creating and caching it on first use.

	The cache entry is tested against None rather than for truthiness, so a
	cached state object that happens to evaluate false is not rebuilt (and
	its simulation progress lost) on every access.
	"""
	state = state_cache.get(chip)
	if state is None:
		state = ChipState(chip)
		state_cache[chip] = state
	return state

#--------------------------------------------------------------

class Chip(Model):
	"""A chip being edited and simulated for one game level.

	Holds the cell grid (a ChipStructure), the pad position of every pin,
	the user's pin setups and the editor's selection.  Simulation state is
	a ChipState obtained through get_chip_state.  Notifications are
	forwarded to the parent's notify().
	"""
	#  level_key       string
	#  setup_list      [ChipSetup]
	#  user_setup      ChipSetup        or None if "Test" setup selected
	#  structure       ChipStructure
	#  pads            [(row,col)]      Pad centres, indexed by pin no.
	#  game_speed      int
	#  selected_rect        rect or None
	#  floating_selection   Snippet or None

	level_key = None
	game_speed = 0
	selected_rect = None
	floating_selection = None
	pad_height = 3

	size = property(lambda self: self.structure.size)
	num_rows = property(lambda self: self.size[0])
	num_cols = property(lambda self: self.size[1])
	state = property(lambda self: get_chip_state(self))
	current_waveforms = property(lambda self: self.current_setup.waveforms)
	current_setup = property(lambda self: self.get_current_setup())
	game = property(lambda self: self.parent.game)
	level = property(lambda self: self.game.level_for_key(self.level_key))
	auto_running = property(lambda self: self.state.auto_mode != AUTO_OFF)

	def __init__(self, parent, key, params):
		"""Create an empty chip for the level identified by `key`.

		parent  -- object providing .game and .notify()
		key     -- level key, used to look the level up via the game
		params  -- ChipParams giving size, pins and initial speed

		If params.size is set, 6 columns are added to it for the pad
		columns on each side.  Otherwise the chip gets 3 rows per pin pair
		and nrows + 6 columns, capped at 30.
		"""
		self.parent = parent
		self.level_key = key
		num_pins = params.num_pins
		if params.size:
			nrows, ncols = params.size
			# Three non-editable pad columns on each side (see
			# add_or_remove_material's column check).
			ncols += 6
		else:
			nrows = 3 * (num_pins // 2)
			ncols = min(nrows + 6, 30)
		self.num_pins = num_pins
		# NOTE(review): divides by num_pins // 2, so assumes num_pins >= 2.
		self.pad_height = nrows // (num_pins // 2)
		self.setup_list = [
			ChipSetup("Lab", params),
		]
		self.user_setup = self.setup_list[0]
		self.structure = ChipStructure(nrows, ncols, self.pad_height)
		self.create_pads(num_pins)
		self.game_speed = params.speed

	def __setstate__(self, d):
		"""Restore from a pickle, upgrading the older setup layout.

		Older pickles stored 'current_setup' instead of 'user_setup' and
		kept the "Test" setup in setup_list; it is filtered out here
		because the test setup now comes from the level.
		"""
		if 'user_setup' not in d:
			d['user_setup'] = d['current_setup']
			d['setup_list'] = [setup for setup in d['setup_list']
				if setup.name != "Test"]
		self.__dict__.update(d)

	def num_pins_on_side(self, s):
		"""Return how many pins are on side s (0 = first half, else second)."""
		n = self.num_pins
		h = n // 2
		if s == 0:
			return h
		else:
			return n - h

	def index_of_pin_on_side(self, s, i):
		"""Map position i on side s to a pin number.

		Side 0 counts up from pin 0; the other side counts down from the
		last pin, matching the placement in create_pads.
		"""
		if s == 0:
			return i
		else:
			return self.num_pins - 1 - i

	def change_input_state(self, pin_index, setup, new_state):
		"""Set `setup.state` (presumably a PinSetup) and notify listeners."""
		setup.state = new_state
		self.changed()
		self.notify('device_state_changed', pin_index)

	def notify(self, *args):
		"""Forward a notification to the parent."""
		self.parent.notify(*args)

	def name_of_pin(self, pin_index):
		"""Return the pin's name in the current setup."""
		return self.current_setup.pins[pin_index].name

	def logic_level_of_pin(self, pin_index):
		"""Return the simulated logic level at the pin's pad."""
		return logic_level_of_pad(self.state, self.pads[pin_index])

	def get_current_setup(self):
		"""Return the user-selected setup, or the level's test setup if none."""
		return self.user_setup or self.get_test_setup()

	def get_test_setup(self):
		"""Return the test setup belonging to this chip's level."""
		return self.game.level_for_key(self.level_key).test_setup

	def get_setup_names(self):
		"""Return the names of all setups, in iter_setups order."""
		return [setup.name for setup in self.iter_setups()]

	def iter_setups(self):
		"""Yield the user setups, then the level's test setup if it has one."""
		for setup in self.setup_list:
			yield setup
		setup = self.get_test_setup()
		if setup:
			yield setup

	def select_setup_with_name(self, name):
		"""Make the user setup called `name` current.

		If no user setup has that name, user_setup becomes None, which
		makes the level's test setup current.  Listeners are notified
		only when the selection actually changes.
		"""
		new_setup = None
		for setup in self.setup_list:
			if setup.name == name:
				new_setup = setup
				break
		if self.user_setup is not new_setup:
			self.user_setup = new_setup
			self.notify('selected_setup_changed')

	def select_test_setup(self):
		"""Select the level's test setup.

		setup_list is not expected to hold a setup named "Test" (older
		pickles have it removed in __setstate__), so this normally sets
		user_setup to None.
		"""
		self.select_setup_with_name("Test")

	def test_setup_selected(self):
		"""Return True if the level's test setup is the current one."""
		return self.user_setup is None

	def create_pads(self, num_pins):
		"""Lay out the pads for all pins down the left and right edges.

		Pins 0.. go down column 1 and pins num_pins-1.. go down column
		num_cols-2, spread evenly over the rows.  At most num_rows // 3
		pads fit per side.
		"""
		self.pads = [None] * num_pins
		col0 = 1
		col1 = self.num_cols - 2
		avail_rows = self.num_rows
		max_per_side = avail_rows // 3
		num_per_side = min(max_per_side, (num_pins + 1) // 2)
		for i in xrange(num_per_side):
			row = (i * avail_rows) // num_per_side + 1
			self.create_pad(row, col0, i)
			self.create_pad(row, col1, num_pins - i - 1)

	def create_pad(self, row, col, pin_no):
		"""Record pin_no's pad centre and fill its 3x3 block with metal.

		Metal connections are also set inside the block, so all 9 cells
		are joined.
		"""
		self.pads[pin_no] = (row, col)
		cells = self.structure.cells
		for i in xrange(row-1, row+2):
			for j in xrange(col-1, col+2):
				cells[i, j, MET] = 1
				# Connect to the cell below / to the right, but not beyond
				# the block's last row / column.
				if i < row+1:
					cells[i, j, MET+VC] = 1
				if j < col+1:
					cells[i, j, MET+HC] = 1

	def add_or_remove_material(self, coords, layer, value):
		"""Set (value != 0) or clear (value == 0) `layer` at coords.

		Returns True if the cell changed.  Only the editable columns
		3 .. ncols-4 can change, and nothing happens unless the layer's
		presence actually flips.  VIA can only be changed where the cell
		has both metal and silicon.  Clearing any layer other than VIA
		also removes the cell's via and the four connections on that
		layer touching the cell (the left and upper neighbours' HC/VC and
		the cell's own).  The changed cell is preserved for undo first.
		"""
		row, col = coords
		cells = self.structure.cells
		ncols = cells.shape[1]
		if not (3 <= col < ncols - 3):
			return False
		cell = cells[coords]
		if (value == 0) ^ (cell[layer] == 0) \
				and not (layer == VIA and not (cell[MET] and cell[SIL])):
			self.preserve_cell_for_undo(coords)
			cell[layer] = value
			if layer != VIA and value == 0:
				cell[VIA] = 0
				self.set_connection((row, col-1), layer, HC, 0)
				self.set_connection((row-1, col), layer, VC, 0)
				self.set_connection((row, col), layer, HC, 0)
				self.set_connection((row, col), layer, VC, 0)
			self.cell_changed(coords)
			return True
		else:
			return False

	def add_connection(self, coords, layer, direction):
		"""Turn on a connection at coords; return True if anything changed."""
		if self.set_connection(coords, layer, direction, 1):
			self.cell_changed(coords)
			return True
		else:
			return False

	def set_connection(self, coords, layer, direction, value):
		"""Set the `layer` connection flag in `direction` (HC or VC) at coords.

		Returns True if the cell changed (it is preserved for undo first),
		otherwise False: the row is outside the chip, the connection is
		not allowed at that column, or it already has `value`.  Sends no
		change notification; callers do that.
		"""
		row, col = coords
		cells = self.structure.cells
		nrows, ncols = cells.shape[:2]
		# Check the row before reading cells: callers pass row-1 for the
		# top row, which would otherwise wrap around to the last row.
		if not (0 <= row < nrows):
			return False
		if direction == HC:
			# A horizontal link joins col and col+1, so both must be
			# editable, except that metal may also join onto pad metal in
			# the edge columns.
			ok = 3 <= col < ncols - 4
			if not ok and layer == MET:
				ok = ((col == 2 and cells[row, col, MET]) or
					(col == ncols - 4 and cells[row, col+1, MET]))
		else:
			ok = 3 <= col < ncols - 3
		if not ok:
			return False
		cell = cells[coords]
		index = layer + direction
		if cell[index] == value:
			return False
		self.preserve_cell_for_undo(coords)
		cell[index] = value
		return True

	def preserve_cell_for_undo(self, coords):
		"""Record the cell at coords with the undo manager."""
		undo = get_undo_manager(self)
		undo.preserve_cell(self, coords)

	def preserve_rect_for_undo(self, rect):
		"""Record rect (clipped to the chip, widened by one row/column up and
		left) with the undo manager."""
		t, l, b, r = sect_rect(rect, self.structure.bounds)
		l = max(0, l - 1)
		t = max(0, t - 1)
		undo = get_undo_manager(self)
		undo.preserve_rect(self, (t, l, b, r))

	def preserve_selection_for_undo(self):
		"""Record the current selection with the undo manager."""
		undo = get_undo_manager(self)
		undo.preserve_selection(self)

	def finish_editing(self):
		"""Tell the undo manager that the current edit is complete."""
		undo = get_undo_manager(self)
		undo.finish_editing()

	def set_cell(self, coords, cell):
		"""Overwrite one cell and notify listeners."""
		self.structure.cells[coords] = cell
		self.cell_changed(coords)

	def set_rect(self, rect, cells):
		"""Overwrite a rectangle of cells and notify listeners."""
		self.structure.set_rect(rect, cells)
		self.cells_changed(rect)

	def cell_changed(self, coords):
		"""Mark the model changed and refresh the cell and its 8 neighbours."""
		self.changed()
		rowc, colc = coords
		nrows, ncols = self.size
		for row in xrange(max(0, rowc-1), min(nrows, rowc+2)):
			for col in xrange(max(0, colc-1), min(ncols, colc+2)):
				icoords = (row, col)
				invalidate_render_buffer(self, icoords)
				self.notify('cell_changed', icoords)

	def cells_changed(self, r):
		"""Mark the model changed and refresh cells around rect r.

		r is (row1, col1, row2, col2).  Rows row1-1 .. row2 and columns
		col1-1 .. col2 are refreshed, clamped to the chip.
		"""
		self.changed()
		row1, col1, row2, col2 = r
		nrows, ncols = self.size
		for row in xrange(max(0, row1-1), min(nrows, row2+1)):
			for col in xrange(max(0, col1-1), min(ncols, col2+1)):
				icoords = (row, col)
				invalidate_render_buffer(self, icoords)
				self.notify('cell_changed', icoords)

	def begin_frame(self):
		"""Run one display frame of simulation at the current game speed."""
		self.apply_input_voltages()
		self.measure_output_voltages()
		self.auto_shutoff_check()
		simulation_timestep(self, steps_per_frame[self.game_speed])
		self.notify('chip_state_changed')

	def apply_input_voltages(self):
		"""Drive the pads of input and supply pins for this frame.

		These pads get impedance 0.  Supply pins, and input pins while auto
		mode is off, are driven at Vcc * pin state.  In auto mode, input
		pins follow their waveform track at the current simulation time
		(0.0 if they have no track).  All other pads get impedance 1.
		"""
		chip_setup = self.current_setup
		pads = self.pads
		chip_state = self.state
		waveforms = chip_setup.waveforms
		auto = chip_state.auto_mode != AUTO_OFF
		t = chip_state.time
		for i, pin_setup in enumerate(chip_setup.pins):
			device = pin_setup.device
			is_input = device in input_devices
			is_supply = device in supply_devices
			if not (is_input or is_supply):
				chip_state.set_pad_impedance(pads[i], 1)
				continue
			chip_state.set_pad_impedance(pads[i], 0)
			if is_supply or not auto:
				v = Vcc * pin_setup.state
			else:
				track = waveforms.track_for_pin(i)
				v = track.sample_at_time(t) if track else 0.0
			chip_state.apply_voltage_to_pad(v, pads[i])

	def measure_output_voltages(self):
		"""In auto mode, record each non-input pin's pad voltage into its track.

		Pins without a waveform track are skipped.
		"""
		state = self.state
		if state.auto_mode == AUTO_OFF:
			return
		t = state.time
		setup = self.current_setup
		waveforms = setup.waveforms
		pads = self.pads
		for i, pin_setup in enumerate(setup.pins):
			if pin_setup.device not in input_devices:
				track = waveforms.track_for_pin(i)
				if track:
					track.set_sample_at_time(t, state.voltage_of_pad(pads[i]))

	def auto_stop(self):
		"""Switch auto mode off."""
		self.set_auto_mode(AUTO_OFF)

	def auto_start(self):
		"""Start (or resume) auto mode."""
		self.set_auto_mode(AUTO_RUN)

	def set_auto_mode(self, mode):
		"""Change auto mode and notify listeners if it actually changes.

		Starting from AUTO_OFF resets simulation time to 0.  Entering a
		mode from anything other than AUTO_RUN recomputes the end time
		from the current setup's input waveforms.
		"""
		state = self.state
		if state.auto_mode != mode:
			if mode == AUTO_RUN and state.auto_mode == AUTO_OFF:
				state.time = 0
			if state.auto_mode != AUTO_RUN:
				state.auto_end_time = self.current_setup.get_input_time_range()
			state.auto_mode = mode
			self.notify('auto_mode_changed')

	def auto_shutoff_check(self):
		"""Stop a running auto mode once its end time is reached.

		Time is reset to 0.  If the test setup was running, its results are
		checked, shown and submitted.
		"""
		state = self.state
		if state.auto_mode == AUTO_RUN:
			if state.time >= state.auto_end_time:
				self.set_auto_mode(AUTO_OFF)
				state.time = 0
				if self.test_setup_selected():
					results = testing.check_test_results(self.level)
					self.game.show_test_results()
					www.submit_results(self, results)

	def cut(self, r):
		"""Return a snippet of rect r and clear it from the chip."""
		snip = self.copy(r)
		self.clear(r)
		return snip

	def copy(self, r):
		"""Return a snippet copied from rect r."""
		return self.structure.copy(r)

	def paste(self, snip):
		"""Write snip into the chip at its position (undoable)."""
		self.preserve_rect_for_undo(snip.bounds)
		r = self.structure.paste(snip)
		self.cells_changed(r)

	def clear(self, r):
		"""Clear rect r (undoable)."""
		self.preserve_rect_for_undo(r)
		self.structure.clear(r)
		self.cells_changed(r)

	def cut_snippet(self):
		"""Remove the selection and return it as a snippet (None if none)."""
		flt = self.float_selection()
		if flt:
			self.invalidate_selection()
			self.floating_selection = None
			return flt

	def copy_snippet(self):
		"""Return the floating snippet, or a copy of the selected rect, or None."""
		flt = self.floating_selection
		if flt:
			return flt
		else:
			sel = self.selected_rect
			if sel:
				return self.copy(sel)

	def paste_snippet(self, snip):
		"""Replace any selection with snip as the floating selection."""
		self.preserve_selection_for_undo()
		self.invalidate_selection()
		self.floating_selection = snip
		self.selected_rect = None
		self.invalidate_selection()
		self.changed()

	def transform_selection(self, func):
		"""Replace the floating selection (floating it first) with func(snippet).

		Both the old and the new bounds are invalidated, because a
		transform can change the snippet's extent and would otherwise
		leave stale cells drawn outside the new bounds.
		"""
		flt = self.float_selection()
		if flt:
			self.invalidate_selection()
			self.floating_selection = func(flt)
			self.changed()
			self.invalidate_selection()

	def flip_horizontal(self):
		"""Flip the selection horizontally."""
		self.transform_selection(lambda snip: snip.flip_horizontal())

	def flip_vertical(self):
		"""Flip the selection vertically."""
		self.transform_selection(lambda snip: snip.flip_vertical())

	def rotate_left(self):
		"""Rotate the selection with snip.rotate(1)."""
		self.transform_selection(lambda snip: snip.rotate(1))

	def rotate_right(self):
		"""Rotate the selection with snip.rotate(3)."""
		self.transform_selection(lambda snip: snip.rotate(3))

	def invalidate_cells(self, r):
		"""Ask listeners to redraw rect r."""
		self.notify('cells_changed', r)

	def invalidate_snippet(self, snip):
		"""Ask listeners to redraw the area covered by snip."""
		self.invalidate_cells(snip.bounds)

	def any_selection(self):
		"""Return True if there is a floating snippet or a selected rect."""
		return bool(self.floating_selection or self.selected_rect)

	def select_rect(self, sel):
		"""Drop any current selection, then select rect sel (if any).

		The selection is recorded for undo only when nothing was selected
		beforehand.
		"""
		if not self.any_selection():
			self.preserve_selection_for_undo()
		self.drop_selection()
		if sel:
			self.selected_rect = sel
			self.invalidate_cells(sel)

	def float_selection(self, copy = False):
		"""Return the floating snippet, creating it from the selected rect.

		If nothing floats yet, the selected rect is copied (copy=True) or
		cut and becomes the floating selection.  If a snippet already
		floats, the selection is just recorded for undo.  Returns None
		when there is no selection.
		"""
		flt = self.floating_selection
		if flt:
			self.preserve_selection_for_undo()
		else:
			sel = self.selected_rect
			if sel:
				if copy:
					flt = self.copy(sel)
				else:
					flt = self.cut(sel)
				self.floating_selection = flt
				self.selected_rect = None
				self.changed()
		return flt

	def begin_selection_drag(self, copy = False):
		"""Prepare to drag the selection.

		Floats the selection if needed.  If a snippet already floats and
		copy is true, a copy is stamped onto the chip at its current
		position before dragging continues.
		"""
		flt = self.floating_selection
		if not flt:
			self.float_selection(copy)
		elif copy:
			self.paste(flt)
			self.changed()

	def position_floating_selection(self, p):
		"""Move the floating selection (floating it first if needed) to p."""
		flt = self.floating_selection
		if not flt:
			flt = self.float_selection()
		if flt and flt.position != p:
			self.invalidate_snippet(flt)
			flt.position = p
			self.invalidate_snippet(flt)
			self.changed()

	def drop_selection(self):
		"""Paste any floating snippet into the chip and clear the selection."""
		self.invalidate_selection()
		flt = self.floating_selection
		if flt:
			self.paste(flt)
			self.floating_selection = None
			self.changed()
		self.selected_rect = None

	def invalidate_selection(self):
		"""Ask listeners to redraw the selection's current area, if any."""
		sel = self.get_selection_rect()
		if sel:
			self.invalidate_cells(sel)

	def get_selection_rect(self):
		"""Return the floating snippet's bounds, else the selected rect."""
		flt = self.floating_selection
		if flt:
			return flt.bounds
		else:
			return self.selected_rect

	def measure_silicon_area(self):
		"""Return the structure's silicon area measure."""
		return self.structure.measure_silicon_area()

	def reset_supply_metering(self):
		"""Zero the state's W array in place."""
		self.state.W[...] = 0

	def measure_supply_current(self, pin_no, time):
		"""Return the metered W values at a pin's pad divided by `time`.

		Returns None for pin_no < 0.
		"""
		if pin_no < 0:
			return None
		row, col = self.pads[pin_no]
		# Pads sit at column 1 or num_cols-2, so col > 2 picks the right-hand
		# side.  NOTE(review): W appears to be indexed [side, row]; confirm
		# against ChipState.  The int() matters: NumPy treats a scalar bool
		# index as a mask (adding an axis), not as the integer 0/1.
		side = int(col > 2)
		return sum(self.state.W[side, row-1:row+1]) / time

