// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_RLS_FiLTER_Hh_
#define DLIB_RLS_FiLTER_Hh_
#include "rls_filter_abstract.h"
#include "../svm/rls.h"
#include <vector>
#include "../matrix.h"
#include "../sliding_buffer.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class rls_filter
{
    /*!
        CONVENTION
            - data.size() == the number of variables in a measurement
            - data[i].size() == data[j].size() for all i and j.
            - data[i].size() == get_window_size()
            - data[i][0] == most recent measurement of i-th variable given to update.
            - data[i].back() == oldest measurement of i-th variable given to update
              (or zero if we haven't seen this much data yet).
            - next.size() == data.size()

            - if (count <= 2) then
                - count == number of times update(z) has been called
            - count never exceeds 2.  Once it reaches 2, every later call to
              update(z) trains the rls object.
    !*/
public:

    rls_filter()
    {
        // Defaults: a window of 5 samples, forget factor 0.8 and C == 100.  These
        // are the same values the explicit constructor uses for its defaults.
        size = 5;
        count = 0;
        filter = rls(0.8, 100);
    }

    explicit rls_filter (
        unsigned long size_,
        double forget_factor = 0.8,
        double C = 100
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 &&
                    0 < C && size_ >= 2,
            "\t rls_filter::rls_filter()"
            << "\n\t invalid arguments were given to this function"
            << "\n\t forget_factor: " << forget_factor
            << "\n\t C: " << C
            << "\n\t size_: " << size_
            << "\n\t this: " << this
            );

        // Members are assigned in the body (not an init list) so that the
        // assertion above is checked before the rls object is constructed.
        size = size_;
        count = 0;
        filter = rls(forget_factor, C);
    }

    // Returns the C parameter of the underlying rls object.
    double get_c(
    ) const
    {
        return filter.get_c();
    }

    // Returns the forget factor of the underlying rls object.
    double get_forget_factor(
    ) const
    {
        return filter.get_forget_factor();
    }

    // Returns the number of past samples per variable fed to the rls object.
    unsigned long get_window_size (
    ) const
    {
        return size;
    }

    /*!
        Advances the filter by one step without a new measurement.  The current
        prediction is pushed into each history buffer as if it had been observed,
        and the prediction is recomputed.  The rls object is NOT trained.  Does
        nothing while the rls object has not been trained yet (empty weights).
    !*/
    void update (
    )
    {
        // No model to extrapolate with until update(z) has trained the rls object.
        if (filter.get_w().size() == 0)
            return;

        for (unsigned long i = 0; i < data.size(); ++i)
        {
            // Put old predicted value into the circular buffer as if it was
            // the measurement we just observed.  But don't update the rls filter.
            data[i].push_front(next(i));
        }

        // predict next state
        for (long i = 0; i < next.size(); ++i)
            next(i) = filter(mat(data[i]));
    }

    /*!
        Feeds a new measurement z (one entry per variable) to the filter.
        The first call allocates one zero-filled history buffer of length
        get_window_size() per variable.  For the first two calls the prediction
        is simply z.  After that, the rls object is trained on each variable's
        history window before z is recorded, and the next state is predicted
        from the updated windows.
    !*/
    template <typename EXP>
    void update (
        const matrix_exp<EXP>& z
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_col_vector(z) == true &&
                    z.size() != 0 &&
                    (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()),
            "\t void rls_filter::update(z)"
            << "\n\t invalid arguments were given to this function"
            << "\n\t is_col_vector(z): " << is_col_vector(z)
            << "\n\t z.size(): " << z.size()
            << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size()
            << "\n\t this: " << this
            );

        // initialize data if necessary
        if (data.size() == 0)
        {
            data.resize(z.size());
            for (long i = 0; i < z.size(); ++i)
                data[i].assign(size, 0);
        }

        for (unsigned long i = 0; i < data.size(); ++i)
        {
            // Once there is some stuff in the circular buffer, start
            // showing it to the rls filter so it can do its thing.
            // Training happens BEFORE z(i) is pushed, so the filter learns to
            // map the previous window onto the value that actually followed it.
            // Note a single rls object is shared by all variables.
            if (count >= 2)
            {
                filter.train(mat(data[i]), z(i));
            }

            // keep track of the measurements in our circular buffer
            data[i].push_front(z(i));
        }

        // Don't bother with the filter until we have seen two samples
        if (count >= 2)
        {
            // predict next state
            for (long i = 0; i < z.size(); ++i)
                next(i) = filter(mat(data[i]));
        }
        else
        {
            // Use current measurement as the next state prediction
            // since we don't know any better at this point.
            ++count;
            next = matrix_cast<double>(z);
        }
    }

    // Returns the prediction for the next measurement (empty before the first
    // call to update(z)).
    const matrix<double,0,1>& get_predicted_next_state(
    ) const
    {
        return next;
    }

    friend inline void serialize(const rls_filter& item, std::ostream& out)
    {
        int version = 1;
        serialize(version, out);
        serialize(item.count, out);
        serialize(item.size, out);
        serialize(item.filter, out);
        serialize(item.next, out);
        serialize(item.data, out);
    }

    friend inline void deserialize(rls_filter& item, std::istream& in)
    {
        int version = 0;
        deserialize(version, in);
        if (version != 1)
            throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object.");

        // Read into temporaries first so that a truncated or corrupt stream
        // leaves item unchanged instead of half overwritten.
        unsigned long new_count = 0;
        unsigned long new_size = 0;
        rls new_filter;
        matrix<double,0,1> new_next;
        std::vector<circular_buffer<double> > new_data;
        deserialize(new_count, in);
        deserialize(new_size, in);
        deserialize(new_filter, in);
        deserialize(new_next, in);
        deserialize(new_data, in);

        // Reject state that breaks the CONVENTION.  Both update() overloads index
        // next(i) for every variable, and update(z) with count >= 2 indexes next
        // without resizing it.  Accepting such a stream would turn into an
        // out-of-bounds access later instead of an error here.
        bool consistent = static_cast<unsigned long>(new_next.size()) == new_data.size() &&
                          !(new_count >= 2 && new_data.empty());
        for (unsigned long i = 0; consistent && i < new_data.size(); ++i)
            consistent = (new_data[i].size() == new_size);
        if (!consistent)
            throw dlib::serialization_error("Corrupt rls_filter object found while deserializing.");

        // Commit.  The rls assignment goes first, so if it throws nothing in
        // item has been modified yet.  The rest are swaps and integer copies.
        item.filter = new_filter;
        item.next.swap(new_next);
        item.data.swap(new_data);
        item.count = new_count;
        item.size = new_size;
    }

private:

    unsigned long count;
    unsigned long size;
    rls filter;
    matrix<double,0,1> next;
    std::vector<circular_buffer<double> > data;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_RLS_FiLTER_Hh_