Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions stan/math/fwd/mat/fun/trace_quad_form.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,39 @@ namespace stan {
namespace math {

template<int RA, int CA, int RB, int CB, typename T>
inline stan::math::fvar<T>
trace_quad_form(const Eigen::Matrix<stan::math::fvar<T>, RA, CA> &A,
const Eigen::Matrix<stan::math::fvar<T>, RB, CB> &B) {
using stan::math::multiply;
using stan::math::multiply;
stan::math::check_square("trace_quad_form", "A", A);
stan::math::check_multiplicable("trace_quad_form",
"A", A,
"B", B);
return stan::math::trace(multiply(stan::math::transpose(B),
multiply(A, B)));
inline fvar<T>
trace_quad_form(const Eigen::Matrix<fvar<T>, RA, CA> &A,
const Eigen::Matrix<fvar<T>, RB, CB> &B) {
check_square("trace_quad_form", "A", A);
check_multiplicable("trace_quad_form",
"A", A,
"B", B);
return trace(multiply(transpose(B),
multiply(A, B)));
}

template<int RA, int CA, int RB, int CB, typename T>
inline fvar<T>
trace_quad_form(const Eigen::Matrix<fvar<T>, RA, CA> &A,
const Eigen::Matrix<double, RB, CB> &B) {
check_square("trace_quad_form", "A", A);
check_multiplicable("trace_quad_form",
"A", A,
"B", B);
return trace(multiply(transpose(B),
multiply(A, B)));
}

template<int RA, int CA, int RB, int CB, typename T>
inline fvar<T>
trace_quad_form(const Eigen::Matrix<double, RA, CA> &A,
const Eigen::Matrix<fvar<T>, RB, CB> &B) {
check_square("trace_quad_form", "A", A);
check_multiplicable("trace_quad_form",
"A", A,
"B", B);
return trace(multiply(transpose(B),
multiply(A, B)));
}
}
}
Expand Down
71 changes: 71 additions & 0 deletions test/unit/math/fwd/mat/fun/trace_quad_form_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,77 @@ TEST(AgradFwdMatrixTraceQuadForm, mat_fd) {
EXPECT_FLOAT_EQ(16126, res.d_);
}

TEST(AgradFwdMatrixTraceQuadForm, mat_d_mat_fd) {
using stan::math::trace_quad_form;
using stan::math::matrix_fd;

Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ad(4,4);
matrix_fd bd(4,2);
fvar<double> res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;

bd(0,0).d_ = 1.0;
bd(0,1).d_ = 1.0;
bd(1,0).d_ = 1.0;
bd(1,1).d_ = 1.0;
bd(2,0).d_ = 1.0;
bd(2,1).d_ = 1.0;
bd(3,0).d_ = 1.0;
bd(3,1).d_ = 1.0;

// fvar<double> - fvar<double>
res = trace_quad_form(ad,bd);
EXPECT_FLOAT_EQ(26758, res.val_);
EXPECT_FLOAT_EQ(5622, res.d_);
}

TEST(AgradFwdMatrixTraceQuadForm, mat_fd_mat_d) {

using stan::math::trace_quad_form;
using stan::math::matrix_fd;

matrix_fd ad(4,4);
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> bd(4,2);
fvar<double> res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;

ad(0,0).d_ = 1.0;
ad(0,1).d_ = 1.0;
ad(0,2).d_ = 1.0;
ad(0,3).d_ = 1.0;
ad(1,0).d_ = 1.0;
ad(1,1).d_ = 1.0;
ad(1,2).d_ = 1.0;
ad(1,3).d_ = 1.0;
ad(2,0).d_ = 1.0;
ad(2,1).d_ = 1.0;
ad(2,2).d_ = 1.0;
ad(2,3).d_ = 1.0;
ad(3,0).d_ = 1.0;
ad(3,1).d_ = 1.0;
ad(3,2).d_ = 1.0;
ad(3,3).d_ = 1.0;

// fvar<double> - fvar<double>
res = trace_quad_form(ad,bd);
EXPECT_FLOAT_EQ(26758, res.val_);
EXPECT_FLOAT_EQ(10504, res.d_);
}

TEST(AgradFwdMatrixTraceQuadForm, mat_ffd) {
using stan::math::trace_quad_form;
using stan::math::matrix_ffd;
Expand Down
94 changes: 94 additions & 0 deletions test/unit/math/mix/mat/fun/trace_quad_form_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include <stan/math/fwd/mat/fun/trace_quad_form.hpp>
#include <stan/math/prim/scal/err/check_not_nan.hpp>
#include <stan/math/rev/scal/fun/value_of_rec.hpp>
#include <stan/math/fwd/scal/fun/value_of_rec.hpp>
#include <stan/math/fwd/mat/fun/typedefs.hpp>
#include <stan/math/mix/mat/fun/typedefs.hpp>
#include <test/unit/math/rev/mat/fun/util.hpp>
Expand Down Expand Up @@ -110,6 +113,97 @@ TEST(AgradMixMatrixTraceQuadForm, mat_fv_1st_deriv) {
EXPECT_FLOAT_EQ(576,h[23]);
}

TEST(AgradMixMatrixTraceQuadForm, mat_dfv_instant) {
using stan::math::trace_quad_form;
using stan::math::matrix_fv;
using stan::math::check_not_nan;

Eigen::Matrix<double, -1, -1> ad(4,4);
matrix_fv bd(4,2);
fvar<var> res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;


// fvar<var> - fvar<var>
res = trace_quad_form(ad,bd);
EXPECT_NO_THROW(check_not_nan("trace_quad_form","res",res));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some sort of expect test for these new test cases.


TEST(AgradMixMatrixTraceQuadForm, mat_fvd_instant) {
using stan::math::trace_quad_form;
using stan::math::matrix_fv;
using stan::math::check_not_nan;

matrix_fv ad(4,4);
Eigen::Matrix<double, -1, -1> bd(4,2);
fvar<var> res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;


// fvar<var> - fvar<var>
res = trace_quad_form(ad,bd);
EXPECT_NO_THROW(check_not_nan("trace_quad_form","res",res));
}

TEST(AgradMixMatrixTraceQuadForm, mat_dffv_instant) {
using stan::math::trace_quad_form;
using stan::math::matrix_ffv;
using stan::math::check_not_nan;

Eigen::Matrix<double, -1, -1> ad(4,4);
matrix_ffv bd(4,2);
fvar<fvar<var> > res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;


// fvar<var> - fvar<var>
res = trace_quad_form(ad,bd);
EXPECT_NO_THROW(check_not_nan("trace_quad_form","res",res));
}

TEST(AgradMixMatrixTraceQuadForm, mat_ffvd_instant) {
using stan::math::trace_quad_form;
using stan::math::matrix_ffv;
using stan::math::check_not_nan;

matrix_ffv ad(4,4);
Eigen::Matrix<double, -1, -1> bd(4,2);
fvar<fvar<var> > res;
bd << 100, 10,
0, 1,
-3, -3,
5, 2;
ad << 2.0, 3.0, 4.0, 5.0,
6.0, 10.0, 2.0, 2.0,
7.0, 2.0, 7.0, 1.0,
8.0, 2.0, 1.0, 112.0;


// fvar<var> - fvar<var>
res = trace_quad_form(ad,bd);
EXPECT_NO_THROW(check_not_nan("trace_quad_form","res",res));
}

TEST(AgradMixMatrixTraceQuadForm, mat_fv_2nd_deriv) {
using stan::math::trace_quad_form;
Expand Down