#! /usr/bin/python

from cctbx import uctbx, sgtbx
import string, math
from pymol.cgo import *
from pymol import cmd

from all_axes import *

from Numeric import *

def read_symop_dat(sg_input=None, filename='symop_axes.dat'):
  """Parse a symmetry-axes data file into per-space-group axis lists.

  File format (as parsed here):
    * Lines that are blank or whose first non-blank character is '#' are
      skipped between records.
    * A header line "SG SGNUM NUM_AX [anything...]" starts a record; only
      SG (the dictionary key) and NUM_AX (an integer) are used.
    * It is followed by exactly NUM_AX axis lines of the form
      "x1 y1 z1 x2 y2 z2 SYMB [anything...]".  The coordinates are
      presumably fractional, since draw_symops orthogonalizes them.

  Args:
    sg_input: optional space-group symbol.  If given, it is upper-cased and
      only that space group's axis list is returned (KeyError if absent).
    filename: path of the data file (defaults to 'symop_axes.dat' in the
      current working directory, as before).

  Returns:
    A list of {'start': [x,y,z], 'end': [x,y,z], 'symb': str} dicts when
    sg_input is given, otherwise a dict mapping each SG to such a list.

  Raises:
    ValueError: on a malformed header, a negative axis count, a truncated
      record, or an axis line with fewer than 7 fields.
  """
  # Do not name this variable 'file': assigning to it would make the builtin
  # 'file' local to this function and the call would fail.
  symop_file = open(filename, 'r')
  try:
    lines = symop_file.readlines()
  finally:
    symop_file.close()

  axes = {}
  num_lines = len(lines)
  linecounter = 0
  while linecounter < num_lines:
    line = lines[linecounter]
    linecounter = linecounter + 1  # now the 1-based number of 'line'
    stripped = line.strip()
    if not stripped or stripped.startswith('#'):
      continue

    fields = stripped.split()
    if len(fields) < 3:
      raise ValueError('%s line %d: malformed space-group header: %r'
                       % (filename, linecounter, line))
    sg = fields[0]
    num_ax = int(fields[2])
    if num_ax < 0:
      raise ValueError('%s line %d: negative axis count: %r'
                       % (filename, linecounter, line))
    if linecounter + num_ax > num_lines:
      raise ValueError('%s line %d: space group %s expects %d axis lines '
                       'but the file ends early'
                       % (filename, linecounter, sg, num_ax))

    sg_axes = []
    for offset in range(num_ax):
      tokens = lines[linecounter + offset].split()
      if len(tokens) < 7:
        raise ValueError('%s line %d: axis line needs 7 fields: %r'
                         % (filename, linecounter + offset + 1,
                            lines[linecounter + offset]))
      sg_axes.append({
        'start': [float(t) for t in tokens[0:3]],
        'end': [float(t) for t in tokens[3:6]],
        'symb': tokens[6],
      })
    # A repeated SG header replaces the earlier entry (original behaviour).
    axes[sg] = sg_axes
    linecounter = linecounter + num_ax

  if sg_input is not None:
    return axes[sg_input.upper()]
  return axes

def draw_symbol(start,end,symb,color,radius=0.2):
  """Return CGO triangles that mark both ends of a symmetry axis.

  Args:
    start, end: 3-element coordinate lists of the axis end points.
    symb: axis symbol string such as '3', '3^1' or '4^2'.
    color: RGB list placed after the COLOR token.
    radius: distance of the marker vertices from each end point.

  Returns:
    A CGO list [BEGIN, TRIANGLES, COLOR, r, g, b, VERTEX, x, y, z, ..., END]
    holding a triangle for 3-fold symbols or a square (two triangles) for
    4-fold symbols at 'start' and then at 'end'.  2-fold, 6-fold and
    unrecognised symbols get an empty list (no marker implemented yet).

  Note: every offset has a zero z component, so the marker always lies in a
  plane normal to the Cartesian z axis, whatever the axis direction.
  """
  to_rad = math.pi/180.
  c30 = math.cos(30.0*to_rad)
  s30 = math.sin(30.0*to_rad)

  if symb in ('3', '3^1', '3^2'):
    # Equilateral triangle with one vertex on +x.
    offsets = [
      [radius, 0, 0],
      [-radius*s30, radius*c30, 0],
      [-radius*s30, -radius*c30, 0],
    ]
  elif symb in ('4', '4^1', '4^2', '4^3'):
    # Square split along its (+,+)/(-,-) diagonal into two triangles.
    offsets = [
      [radius, radius, 0], [-radius, radius, 0], [-radius, -radius, 0],
      [radius, radius, 0], [radius, -radius, 0], [-radius, -radius, 0],
    ]
  else:
    # '2', '2^1', the 6-fold family and anything unknown: nothing drawn.
    return []

  cgo = [ BEGIN, TRIANGLES, COLOR ] + color
  for anchor in (start, end):
    for offset in offsets:
      cgo.append(VERTEX)
      cgo.extend((array([anchor]) + array(offset))[0].tolist())
  cgo.append(END)
  return cgo

def draw_symops(cell_param_list,sg,radius=0.2):
  """
  Draw the symmetry axes of a space group as PyMOL CGO objects.

  From PyMOL, issue the "run draw_symops.py" command to load the script,
  then issue the "draw_symops((cell_param_list),<SpaceGroup_string>,<optional radius>)"
  command to run it and create the CGO objects.
  E.g. "draw_symops((45.2,45.2,70.8,90.,90.,120.),'p3121',0.5)" generates the
  symmetry operators for the trigonal space group "P 31 2 1".
  Each axis type is loaded as a separate object named "<SPACEGROUP>_<symbol>",
  so the types appear separately in the PyMOL menu and can be turned on and
  off individually.

  Args:
    cell_param_list: unit-cell parameters (a, b, c, alpha, beta, gamma).
    sg: space-group symbol; upper-cased before lookup in the axes data file.
    radius: cylinder radius; end markers are drawn at 3 * radius.
  """
  U = uctbx.unit_cell(cell_param_list)

  # Colour legend.
  # Rotation axes:
  #   "2" yellow, "3" orange, "4" mauve, "6" purple
  # Screw axes (all sub-1 screws are green):
  #   "2^1" green, "3^1" green, "3^2" lime,
  #   "4^1" green, "4^2" cyan, "4^3" iceblue,
  #   "6^1" green, "6^2" silver, "6^3" cyan, "6^4" iceblue, "6^5" blue
  color = {
    "2" : [1.0, 1.0, 0.0],
    "3" : [1.0, 0.5, 0.0],
    "4" : [1.0, 0.5, 1.0],
    "6" : [1.0, 0.0, 1.0],
    "2^1" : [0.0, 1.0, 0.0],
    "3^1" : [0.0, 1.0, 0.0],
    "3^2" : [0.5, 1.0, 0.5],
    "4^1" : [0.0, 1.0, 0.0],
    "4^2" : [0.0, 1.0, 1.0],
    "4^3" : [0.5, 0.5, 1.0],
    "6^1" : [0.0, 1.0, 0.0],
    "6^2" : [0.8, 0.8, 0.8],
    "6^3" : [0.0, 1.0, 1.0],
    "6^4" : [0.5, 0.5, 1.0],
    "6^5" : [0.0, 0.0, 1.0],
    }

  sg = sg.upper()
  symop_axes = read_symop_dat(sg)

  # One CGO list per axis symbol, so each symbol becomes its own object.
  ax_obj = {}
  for axis in symop_axes:
    # orthogonalize() converts the file coordinates to Cartesian; list() gives
    # a plain list so it can be concatenated into the CGO list.
    start = list(U.orthogonalize(axis['start']))
    end = list(U.orthogonalize(axis['end']))
    symb_ax = axis['symb']
    ax_color = color[symb_ax]  # KeyError for symbols without a colour entry

    obj = ax_obj.setdefault(symb_ax, [])
    # CYLINDER, x1,y1,z1, x2,y2,z2, radius, rgb1, rgb2
    obj.append(CYLINDER)
    obj.extend(start + end + [radius])
    obj.extend(ax_color + ax_color)
    obj.extend(draw_symbol(start, end, symb_ax, ax_color, radius*3.))

  for key, obj in ax_obj.items():
    name = sg + "_" + key
    cmd.load_cgo(obj, name)

# Register draw_symops as a PyMOL command, so that after "run draw_symops.py"
# it can be invoked as draw_symops(...) from the PyMOL command line.
cmd.extend("draw_symops",draw_symops)
