
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include "openglheader.h"

#include "utilities.h"
#include "GPUsparsemat.h"

static GLuint program_id;
static GLuint uloc[12];

/* Builds the compute program used by this module: the single source file   */
/* "multmmf.comp.glsl" is passed to CompileShaderFiles, the resulting       */
/* shader object is passed to LinkShaderProgram, and the location of each   */
/* of the 12 named uniforms is stored in uloc[] in the order listed below.  */
/* The shader object is deleted once the program has been linked.           */
/* ExitIfGLError is called at the end (presumably it terminates the         */
/* application if a GL error is pending - see utilities).                   */
void LoadGPUMultMMfShader ( void )
{
  static const char *source_files[] =
    { "multmmf.comp.glsl" };
  static const GLchar *uniform_names[] =
    { "stage", "prN0", "prN", "prStep",
      "ma", "nnza", "mb", "nprod", "nnzc", "h", "reverse", "tablgt" };
  GLuint compute_shader;
  int    u;

  compute_shader = CompileShaderFiles ( GL_COMPUTE_SHADER, 1, source_files );
  program_id = LinkShaderProgram ( 1, &compute_shader, "GPUMultMMf" );
      /* look up every uniform once, so that later calls can use uloc[] */
  u = 0;
  while ( u < 12 ) {
    uloc[u] = glGetUniformLocation ( program_id, uniform_names[u] );
    u++;
  }
  glDeleteShader ( compute_shader );
  ExitIfGLError ( "LoadGPUMultMMfShader" );
} /*LoadGPUMultMMfShader*/

/* Deletes the compute program created by LoadGPUMultMMfShader.             */
/* The stored name is reset to 0 afterwards; glDeleteProgram silently       */
/* ignores the name 0, so calling this procedure again (or before the       */
/* program has been loaded) does not raise a GL error.                      */
void DeleteGPUMultMMfProgram ( void )
{
  glDeleteProgram ( program_id );
  program_id = 0;
} /*DeleteGPUMultMMfProgram*/

/* COMPUTE dispatches the currently bound compute program with the given    */
/* numbers of work groups and then issues a shader storage memory barrier.  */
/* EXECSTAGE first sets the uniform at uloc[0] ("stage") and then does      */
/* the same.  Both macros expand to a brace-enclosed block, therefore       */
/* a call site may be written with or without a trailing semicolon.         */
#define COMPUTE(NX,NY,NZ) \
  { glDispatchCompute ( (NX), (NY), (NZ) ); \
    glMemoryBarrier ( GL_SHADER_STORAGE_BARRIER_BIT ); }
#define EXECSTAGE(STAGENO,NX,NY,NZ) \
  { glUniform1i ( uloc[0], (STAGENO) );  COMPUTE ( NX, NY, NZ ) }

/* Runs stage 0 of the compute program repeatedly over N items.             */
/* N0 and N are passed to the shader as the uniforms prN0 and prN.          */
/* One pass is dispatched for each significant bit of N-1, with the pass    */
/* number passed as the uniform prStep; every pass uses (N+1)/2 work        */
/* groups.  Nothing is dispatched if N <= 1.                                */
/* NOTE(review): the name suggests a parallel prefix sum of integers held   */
/* in a buffer bound by the caller - confirm against the shader source.     */
static void iPrefixSum ( int N0, int N )
{
  int step, span, ngroups;

  ngroups = (N+1)/2;
  glUniform1i ( uloc[0], 0 );      /* uniform stage = 0; */
  glUniform1ui ( uloc[1], N0 );    /* uniform prN0 = N0; */
  glUniform1ui ( uloc[2], N );     /* uniform prN = N; */
      /* span is halved after each pass, which gives one pass */
      /* per significant bit of N-1 */
  step = 0;
  span = N-1;
  while ( span > 0 ) {
    glUniform1ui ( uloc[3], step );  /* uniform prStep = step; */
    COMPUTE ( ngroups, 1, 1 );
    step++;
    span >>= 1;
  }
  ExitIfGLError ( "iPrefixSum" );
} /*iPrefixSum*/

/* Runs stage 6 of the compute program as a sequence of passes over nseq    */
/* sequences, each having at most maxnt elements.  Nothing is done if       */
/* maxnt < 2.                                                               */
/* Let steps be the number of significant bits of maxnt-1 and let           */
/* gsize = 2^(steps-1).  Every pass dispatches nseq x gsize x 1 work        */
/* groups.  For h = 2, 4, ..., 2^steps there is one pass with the uniform   */
/* reverse set to true and the uniform h set to h, followed by passes with  */
/* reverse set to false and the uniform h set to h/2, h/4, ..., 2.          */
/* NOTE(review): this schedule looks like a bitonic sorting network         */
/* applied to all sequences in parallel - confirm against the shader.       */
static void NetSort ( GLuint nseq, GLuint maxnt )
{
  GLuint steps, nn, h, h2, i, gsize;

  if ( maxnt < 2 )
    return;
      /* count the significant bits of maxnt-1 */
  for ( nn = maxnt-1, steps = 0;  nn;  nn >>= 1, steps++ )
    ;
      /* gsize is half of the padded length 2^steps; it is computed with    */
      /* an unsigned shift by steps-1 (steps >= 1 here), because shifting   */
      /* the signed constant 1 by 31 bits would be undefined behaviour      */
  gsize = (GLuint)1 << (steps-1);
  glUniform1i ( uloc[0], 6 );         /* uniform stage = 6; */
  glUniform1ui ( uloc[1], gsize );    /* uniform prN0 = gsize; */
  glUniform1ui ( uloc[2], maxnt );    /* uniform prN = maxnt; */
  for ( i = 0, h2 = 1, h = 2;  i < steps;  i++, h2 = h, h += h ) {
    glUniform1i ( uloc[10], GL_TRUE );   /* uniform reverse = true; */
    glUniform1ui ( uloc[9], h );      /* uniform h = h; */
    COMPUTE ( nseq, gsize, 1 )
    glUniform1i ( uloc[10], GL_FALSE );  /* uniform reverse = false; */
        /* on entry h2 == h/2; the loop leaves h2 == 1 and the outer */
        /* loop header then resets it to the current h */
    for ( ;  h2 > 1;  h2 >>= 1 ) {
      glUniform1ui ( uloc[9], h2 );   /* uniform h = h2; */
      COMPUTE ( nseq, gsize, 1 )
    }
  }
  ExitIfGLError ( "NetSort" );
} /*NetSort*/

/* Multiplies two sparse matrices on the GPU, using the compute program     */
/* loaded by LoadGPUMultMMfShader.                                          */
/* Parameters:                                                              */
/*   c - receives the result: m = a->m, n = b->n, nnz = the count read back */
/*       after stage 7, lmax = 0, buf[0] and buf[1] = two buffer objects    */
/*       generated here.  These buffers are not deleted by this procedure   */
/*       (presumably the caller owns them from now on).                     */
/*   a, b - the factors; their buf[0] and buf[1] are bound to the shader    */
/*       storage binding points 0, 1 (a) and 2, 3 (b).  a->n is passed as   */
/*       the uniform mb; no check is made that it matches b.                */
/* Return value: false if the count read back after stage 1 (nprod) is      */
/*   zero - then *c is filled with zero bytes and no buffer is left         */
/*   allocated; true otherwise.                                             */
/* NOTE(review): the meaning of the shader stages 1-10 is not visible here; */
/* the remarks about them below are guesses based on the uniform names and  */
/* must be confirmed against multmmf.comp.glsl.                             */
/* glBufferData and glGetBufferSubData below act on the buffer most         */
/* recently bound with glBindBufferBase, as that call also sets the generic */
/* GL_SHADER_STORAGE_BUFFER binding.                                        */
char GPUMultSparseMatricesf ( GPUSparseMatrix *c,
                              GPUSparseMatrix *a, GPUSparseMatrix *b )
{
  GLuint auxb[5];
  GLuint nprod, _nnzc, maxnt, tablgt, i;

  glUseProgram ( program_id );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 0, a->buf[0] );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 1, a->buf[1] );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 2, b->buf[0] );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 3, b->buf[1] );
  glUniform1ui ( uloc[4], a->m );      /* uniform ma = a->m; */
  glUniform1ui ( uloc[5], a->nnz );    /* uniform nnza = a->nnz; */
  glUniform1ui ( uloc[6], a->n );      /* uniform mb = a->n; */
      /* auxb[0..2] are work buffers, auxb[3..4] will be given to c */
  glGenBuffers ( 5, auxb );
      /* binding 4: a->nnz+1 integers, processed by stage 1 and by */
      /* iPrefixSum; the last of them is then read back as nprod */
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 4, auxb[0] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 (a->nnz+1)*sizeof(GLuint), NULL, GL_DYNAMIC_DRAW );
  EXECSTAGE ( 1, a->nnz, 1, 1 );
  iPrefixSum ( 1, a->nnz );
      /* glGetBufferSubData sees shader writes only after a barrier with */
      /* GL_BUFFER_UPDATE_BARRIER_BIT; the barrier issued by COMPUTE     */
      /* orders shader storage accesses only */
  glMemoryBarrier ( GL_BUFFER_UPDATE_BARRIER_BIT );
  glGetBufferSubData ( GL_SHADER_STORAGE_BUFFER,
                       a->nnz*sizeof(GLuint), sizeof(GLuint), &nprod );
  if ( !nprod ) {
        /* nothing to compute: release everything, unbind the program */
        /* (as the regular exit path does) and return an empty result */
    glDeleteBuffers ( 5, auxb );
    glUseProgram ( 0 );
    memset ( c, 0, sizeof(GPUSparseMatrix) );
    return false;
  }
  glUniform1ui ( uloc[7], nprod );     /* uniform nprod = nprod; */
      /* binding 6: 2*nprod integers, binding 5: nprod floats */
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 6, auxb[1] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 2*nprod*sizeof(GLuint), NULL, GL_DYNAMIC_DRAW );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 5, auxb[2] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 nprod*sizeof(GLfloat), NULL, GL_DYNAMIC_DRAW );
  EXECSTAGE ( 2, nprod, 1, 1 );
      /* tablgt = max ( nprod, a->m ) + 1 */
  tablgt = nprod > a->m ? nprod+1 : a->m+1;
  glUniform1ui ( uloc[11], tablgt );   /* uniform tablgt = tablgt; */
      /* replace the buffer at binding 4 by a new one with 2*tablgt */
      /* integers */
  glDeleteBuffers ( 1, &auxb[0] );
  glGenBuffers ( 1, &auxb[0] );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 4, auxb[0] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 2*tablgt*sizeof(GLuint), NULL, GL_DYNAMIC_DRAW );
  EXECSTAGE ( 3, a->m, 1, 1 );
  EXECSTAGE ( 4, a->m, 1, 1 );
      /* stage 5 is run with the item count i halved (rounding up) in each */
      /* pass - presumably a reduction, whose result is read as maxnt */
  glUniform1i ( uloc[0], 5 );  /* uniform stage = 5; */
  for ( i = a->m; i > 1; i = (i+1)/2 ) {
    glUniform1ui ( uloc[2], i );       /* uniform prN = i; */
    COMPUTE ( i/2, 1, 1 );
  }
  glMemoryBarrier ( GL_BUFFER_UPDATE_BARRIER_BIT );
  glGetBufferSubData ( GL_SHADER_STORAGE_BUFFER,
                       tablgt*sizeof(GLuint), sizeof(GLuint), &maxnt );
  NetSort ( a->m, maxnt );
  EXECSTAGE ( 7, nprod, 1, 1 );
  iPrefixSum ( 1, nprod );
      /* the integer at the index nprod of the buffer at binding 4 is */
      /* taken as the number of nonzero coefficients of the result    */
  glMemoryBarrier ( GL_BUFFER_UPDATE_BARRIER_BIT );
  glGetBufferSubData ( GL_SHADER_STORAGE_BUFFER,
                       nprod*sizeof(GLuint), sizeof(GLuint), &_nnzc );
  glUniform1ui ( uloc[8], _nnzc );     /* uniform nnzc = _nnzc; */
      /* the buffers of the result replace the buffers of a at the */
      /* bindings 1 (_nnzc floats) and 0 (a->m+1+_nnzc integers)   */
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 1, auxb[4] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 _nnzc*sizeof(GLfloat), NULL, GL_DYNAMIC_DRAW );
  glBindBufferBase ( GL_SHADER_STORAGE_BUFFER, 0, auxb[3] );
  glBufferData ( GL_SHADER_STORAGE_BUFFER,
                 (a->m+1+_nnzc)*sizeof(GLuint), NULL, GL_DYNAMIC_DRAW );
  EXECSTAGE ( 8, nprod+1, 1, 1 );
  EXECSTAGE ( 9, _nnzc, 1, 1 );
  EXECSTAGE ( 10, a->m, 1, 1 );
  c->m = a->m;  c->n = b->n;  c->nnz = _nnzc;  c->lmax = 0;
  c->buf[0] = auxb[3];
  c->buf[1] = auxb[4];
  glUseProgram ( 0 );
      /* only the work buffers auxb[0..2] are deleted */
  glDeleteBuffers ( 3, auxb );
  ExitIfGLError ( "GPUMultSparseMatricesf" );
  return true;
} /*GPUMultSparseMatricesf*/

