import os
from time import sleep
from warnings import warn

# bring the 'env' construction environment into this script's namespace
Import('env')
# work on a clone so that the builders appended below stay local to my_env
my_env = env.Clone()

# absolute form of the relative path "../../thrust/"
# NOTE(review): not referenced anywhere else in the visible part of this file
thrust_abspath = os.path.abspath("../../thrust/")

# this function builds a trivial source file from a Thrust header
def trivial_source_from_header(source, target, env):
  """Builder action: write a source file which merely #includes every header in source.

  source -- the headers to include; str() of each entry is used as the include path
  target -- sequence whose first element names the file to generate
  env    -- the construction environment the action is invoked with; its $CXX
            decides whether <windows.h> is included first
  """
  target_filename = str(target[0])

  # the with-statement closes the file even if one of the writes raises
  with open(target_filename, 'w') as fid:
    # make sure we don't trip over <windows.h> when compiling with cl.exe
    # (ask the env handed to the action, so that per-target overrides of CXX
    # are honored and the function does not depend on a module-level global)
    if env.subst('$CXX') == 'cl':
      fid.write('#include <windows.h>\n')

    for src in source:
      fid.write('#include <' + str(src) + '>\n')

  # XXX WAR race condition on Windows discussed here:
  #         http://scons.tigris.org/ds/viewMessage.do?dsForumId=1272&dsMessageId=807348
  if os.name == 'nt':
    sleep(0.1)


# both builders below run the same action on '.h' sources; they differ only
# in the suffix declared for the generated file

# CUFile builds a trivial .cu file from a Thrust header
cu_from_header_builder = Builder(action = trivial_source_from_header, suffix = '.cu', src_suffix = '.h')

# CPPFile builds a trivial .cpp file from a Thrust header
cpp_from_header_builder = Builder(action = trivial_source_from_header, suffix = '.cpp', src_suffix = '.h')

# register the two builders with my_env in one go
my_env.Append(BUILDERS = {'CUFile'  : cu_from_header_builder,
                          'CPPFile' : cpp_from_header_builder})

# gather all public thrust headers
public_thrust_headers = my_env.RecursiveGlob('*.h', '#thrust', exclude='detail|system')

# headers located directly in thrust/system
public_thrust_headers.extend(my_env.Glob('*.h', '#thrust/system'))

# omit headers from systems which are not the host or device system:
# collect the distinct backends first, then glob each backend's directory once
backends = [env['host_backend']]
if env['device_backend'] != env['host_backend']:
  backends.append(env['device_backend'])

for backend in backends:
  backend_headers = my_env.RecursiveGlob('*.h', '#thrust/system/' + backend, exclude='detail')
  public_thrust_headers.extend(backend_headers)

sources = []

# the directory which header paths are made relative to; looked up once
thrust_dir = Dir('#thrust')

# searched for as a byte string: under Python 3 get_contents() yields bytes,
# and testing a str for membership in bytes raises a TypeError
config_include = b'#include <thrust/detail/config.h>'

# generate one trivial .cu and one trivial .cpp source per public header
for hdr in public_thrust_headers:
  rel_path = thrust_dir.rel_path(hdr)

  # replace slashes with '_slash_'
  src_filename = rel_path.replace('/', '_slash_').replace('\\', '_slash_')

  # strip only the trailing extension; a blanket replace('.h', ...) would
  # also rewrite any '.h' occurring earlier in the path
  base_name = os.path.splitext(src_filename)[0]

  cu  = my_env.CUFile(base_name + '.cu', hdr)
  cpp = my_env.CPPFile(base_name + '_cpp.cpp', hdr)

  sources.extend([cu,cpp])

  # ensure that all files #include <thrust/detail/config.h>
  if config_include not in hdr.get_contents():
    warn('Header ' + str(hdr) + ' does not include <thrust/detail/config.h>')

# generate source files which #include all headers
all_headers_cu  = my_env.CUFile('all_headers.cu', public_thrust_headers)
# the .cpp variant goes through CPPFile, matching how the per-header .cpp
# sources are generated
all_headers_cpp = my_env.CPPFile('all_headers_cpp.cpp', public_thrust_headers)

sources.append(all_headers_cu)
sources.append(all_headers_cpp)

# and the file with main()
sources.append('main.cu')

# build the tester
tester = my_env.Program('tester', sources)

# the action attached to both aliases below: the absolute path of the built tester
run_tester = tester[0].abspath

# add the tester to the 'run_trivial_tests' alias and
# always build that alias whether or not it needs it
tester_alias = my_env.Alias('run_trivial_tests', [tester], run_tester)
my_env.AlwaysBuild(tester_alias)

# add the trivial tests alias to the 'run_tests' alias
my_env.Alias('run_tests', [tester], run_tester)

