Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/base/flamec/include/FLA_lapack_prototypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ FLA_Error FLA_Svdd_external( FLA_Svd_type jobz, FLA_Obj A, FLA_Obj s, FLA_Obj U,

// --- external HIP prototypes -------------------------------------------------
#ifdef FLA_ENABLE_HIP
FLA_Error FLA_Apply_pivots_unb_external_hip( rocblas_handle handle, FLA_Side side, FLA_Trans trans, FLA_Obj p, FLA_Obj A, void* A_hip );
FLA_Error FLA_Apply_pivots_ln_unb_ext_hip( rocblas_handle handle, FLA_Obj p, FLA_Obj A, void* A_hip );
FLA_Error FLA_Apply_Q_blk_external_hip( rocblas_handle handle, FLA_Side side, FLA_Trans trans, FLA_Store storev, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip, FLA_Obj B, void* B_hip );
FLA_Error FLA_Bidiag_apply_U_external_hip( rocblas_handle handle, FLA_Side side, FLA_Trans trans, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip, FLA_Obj B, void* B_hip );
FLA_Error FLA_Bidiag_apply_V_external_hip( rocblas_handle handle, FLA_Side side, FLA_Trans trans, FLA_Obj A, void* A_hip, FLA_Obj t, void* t_hip, FLA_Obj B, void* B_hip );
Expand Down Expand Up @@ -249,6 +251,7 @@ FLA_Error FLA_Trinv_ln_blk_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hi
FLA_Error FLA_Trinv_lu_blk_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip );
FLA_Error FLA_Trinv_un_blk_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip );
FLA_Error FLA_Trinv_uu_blk_ext_hip( rocblas_handle handle, FLA_Obj A, void* A_hip );
FLA_Error FLA_Trsm_piv_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj p );
#endif

// --- check routine prototypes ------------------------------------------------
Expand Down
15 changes: 15 additions & 0 deletions src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t,
typedef FLA_Error(*flash_eig_gest_hip_p)(rocblas_handle handle, FLA_Inv inv, FLA_Uplo uplo, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip );
typedef FLA_Error(*flash_lu_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p );
typedef FLA_Error(*flash_lu_piv_copy_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj p, FLA_Obj U, void* U_hip );
typedef FLA_Error(*flash_trsm_piv_hip_p)(rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj p );

// Level-3 BLAS
typedef FLA_Error(*flash_gemm_hip_p)(rocblas_handle handle, FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj beta, FLA_Obj C, void* C_hip);
Expand Down Expand Up @@ -562,6 +563,20 @@ void FLASH_Queue_exec_task_hip( FLASH_Task* t,
t->output_arg[1],
output_arg[1] );
}
// FLA_Trsm_piv
else if ( t->func == (void *) FLA_Trsm_piv_task )
{
flash_trsm_piv_hip_p func;
func = (flash_trsm_piv_hip_p) FLA_Trsm_piv_external_hip;

func( handle,
t->input_arg[0],
input_arg[0],
t->output_arg[0],
output_arg[0],
t->fla_arg[0]
);
}
// FLA_Gemm
else if ( t->func == (void *) FLA_Gemm_task )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ also to create a macro for when it is not below to return an error code.
(void *) cntl, \
"Trsm ", \
FALSE, \
FALSE, \
TRUE, \
0, 1, 1, 1, \
p, A, C )

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*

Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2023, Advanced Micro Devices, Inc.

This file is part of libflame and is available under the 3-Clause
BSD license, which can be found in the LICENSE file at the top-level
directory, or at http://opensource.org/licenses/BSD-3-Clause

*/

#include "FLAME.h"

#ifdef FLA_ENABLE_HIP

#include "hip/hip_runtime_api.h"
#include "rocblas/rocblas.h"
#include "rocsolver/rocsolver.h"

FLA_Error FLA_Apply_pivots_unb_external_hip( rocblas_handle handle, FLA_Side side, FLA_Trans trans, FLA_Obj p, FLA_Obj A, void* A_hip )
{
FLA_Datatype datatype;
int n_A, cs_A;
int m_p;
int inc_p;
int* buff_p;
int k1_1, k2_1;
int* pivots_lapack;
int i;

if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
FLA_Apply_pivots_check( side, trans, p, A );

if ( FLA_Obj_has_zero_dim( A ) ) return FLA_SUCCESS;

datatype = FLA_Obj_datatype( A );

n_A = FLA_Obj_width( A );
cs_A = FLA_Obj_col_stride( A );

inc_p = FLA_Obj_vector_inc( p );
m_p = FLA_Obj_vector_dim( p );

buff_p = FLA_INT_PTR( p );

// Use one-based indices for LAPACK.
k1_1 = 1;
k2_1 = m_p;

// Translate FLAME pivot indices to LAPACK-compatible indices. It is
// important to note that this conversion, unlike the one done by
// FLA_Shift_pivots_to(), is NOT in-place, but rather done separately
// in a temporary buffer.
hipMallocManaged( (void**) &pivots_lapack, m_p * sizeof( int ), hipMemAttachGlobal );
// this implies a sync

for ( i = 0; i < m_p; i++ )
{
pivots_lapack[ i ] = buff_p[ i ] + i + 1;
}

void* A_mat = NULL;
if ( FLASH_Queue_get_malloc_managed_enabled_hip() )
{
A_mat = FLA_Obj_buffer_at_view( A );
}
else
{
A_mat = A_hip;
}

switch ( datatype ){

case FLA_FLOAT:
{
float* buff_A = ( float * ) A_mat;

rocsolver_slaswp( handle,
n_A,
buff_A, cs_A,
k1_1,
k2_1,
pivots_lapack,
inc_p );
break;
}

case FLA_DOUBLE:
{
double* buff_A = ( double * ) A_mat;

rocsolver_dlaswp( handle,
n_A,
buff_A, cs_A,
k1_1,
k2_1,
pivots_lapack,
inc_p );
break;
}

case FLA_COMPLEX:
{
rocblas_float_complex* buff_A = ( rocblas_float_complex * ) A_mat;

rocsolver_claswp( handle,
n_A,
buff_A, cs_A,
k1_1,
k2_1,
pivots_lapack,
inc_p );
break;
}

case FLA_DOUBLE_COMPLEX:
{
rocblas_double_complex* buff_A = ( rocblas_double_complex * ) A_mat;

rocsolver_zlaswp( handle,
n_A,
buff_A, cs_A,
k1_1,
k2_1,
pivots_lapack,
inc_p );
break;
}

}

hipFree( pivots_lapack );

return FLA_SUCCESS;
}

FLA_Error FLA_Apply_pivots_ln_unb_ext_hip( rocblas_handle handle, FLA_Obj p, FLA_Obj A, void* A_hip )
{
return FLA_Apply_pivots_unb_external_hip( handle, FLA_LEFT, FLA_NO_TRANSPOSE, p, A, A_hip );
}

#endif
31 changes: 31 additions & 0 deletions src/base/flamec/wrappers/lapack/hip/FLA_Trsm_piv_external_hip.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*

Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2023, Advanced Micro Devices, Inc.

This file is part of libflame and is available under the 3-Clause
BSD license, which can be found in the LICENSE file at the top-level
directory, or at http://opensource.org/licenses/BSD-3-Clause

*/

#include "FLAME.h"

#ifdef FLA_ENABLE_HIP

#include "hip/hip_runtime_api.h"
#include "rocblas/rocblas.h"

FLA_Error FLA_Trsm_piv_external_hip( rocblas_handle handle, FLA_Obj A, void* A_hip, FLA_Obj B, void* B_hip, FLA_Obj p )
{
FLA_Apply_pivots_unb_external_hip( handle, FLA_LEFT, FLA_NO_TRANSPOSE,
p, B , B_hip);

FLA_Trsm_external_hip( handle, FLA_LEFT, FLA_LOWER_TRIANGULAR,
FLA_NO_TRANSPOSE, FLA_UNIT_DIAG,
FLA_ONE, A, A_hip, B, B_hip);

return FLA_SUCCESS;
}

#endif