#include "initCUDA.h"

#include <iostream>
#include <stdexcept>
#include <oogl/glIncludes.h>
#include <cuda_gl_interop.h>
#include <utils/log.h>

namespace cudaul {

void initCUDA(unsigned device_number, bool use_opengl) {
	int device_count;
	cudaError_t error;
	cudaDeviceProp device_prop;

	error = cudaGetDeviceCount(&device_count);
	if (error) {
		LOG_ERROR << "Failed getting device count" << std::endl;
		throw std::runtime_error("Failed getting device count");
	}
	if (device_count <= 0) {
		LOG_ERROR << "Did not find suitable CUDA device" << std::endl;
		throw std::runtime_error("Did not find suitable CUDA device");
	}
	if ((unsigned) device_count < device_number) {
		LOG_ERROR << "No such device number: " << device_number << std::endl;
		throw std::runtime_error("No such device number");
	}
	error = cudaGetDeviceProperties(&device_prop, device_number);
	if (error) {
		LOG_ERROR << "Failed getting device property for given device number: " << device_number << std::endl;
		throw std::runtime_error("Failed getting device property for given device number");
	}
	if (device_prop.major < 1) {
		LOG_ERROR << "Device does not support CUDA" << std::endl;
		throw std::runtime_error("Device does not support CUDA");
	}
	if (!use_opengl) {
		LOG_DEBUG << "don't use opengl" << std::endl;
		error = cudaSetDevice(device_number);
	} else {
		LOG_DEBUG << "use opengl interop" << std::endl;
		error = cudaGLSetGLDevice(device_number);
	}
	if (error) {
		LOG_ERROR << "Failed to use device" << std::endl;
		throw std::runtime_error("Failed to use device");
	}

	dumpInfos(device_prop);
}

void initCUDAWithOpenGL(unsigned device_number) {
	initCUDA(device_number,true);
}

void dumpInfos(cudaDeviceProp &device_prop) {
	if(!LOG_IS_INFO_ENABLED)
		return;

	LOG_INFO << std::endl
		<< "Clock Rate: " << device_prop.clockRate << " kHz\n"
		<< "Multiprocessor Count: " << device_prop.multiProcessorCount << std::endl
		<< "Regs per Block: " << device_prop.regsPerBlock << std::endl
		<< "Shared mem per block: " << device_prop.sharedMemPerBlock << std::endl
		<< "Global Memory: " << device_prop.totalGlobalMem / (1024 * 1024) << " MB" << std::endl
		<< "Warp Size: " << device_prop.warpSize << std::endl
		<< "Max Grid Size: " << device_prop.maxGridSize[0] << "x" << device_prop.maxGridSize[1] << "x" << device_prop.maxGridSize[2] << std::endl
		<< "Max Thread Dim: " << device_prop.maxThreadsDim[0] << "x" << device_prop.maxThreadsDim[1] << "x" << device_prop.maxThreadsDim[2] << std::endl
		<< "Max Thread Per Block: " << device_prop.maxThreadsPerBlock << std::endl
		<< "Mapping of host memory is " << (device_prop.canMapHostMemory ? "supported" : "not supported") << std::endl;

}

}
