// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
#define DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
#include "structural_sequence_segmentation_trainer_abstract.h"
#include "structural_sequence_labeling_trainer.h"
#include "sequence_segmenter.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class structural_sequence_segmentation_trainer
{
public:
typedef typename feature_extractor::sequence_type sample_sequence_type;
typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
typedef sequence_segmenter<feature_extractor> trained_function_type;
explicit structural_sequence_segmentation_trainer (
const feature_extractor& fe_
) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_))
{
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
}
structural_sequence_segmentation_trainer (
)
{
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
}
const feature_extractor& get_feature_extractor (
) const { return trainer.get_feature_extractor().fe; }
void set_num_threads (
unsigned long num
)
{
trainer.set_num_threads(num);
}
unsigned long get_num_threads (
) const
{
return trainer.get_num_threads();
}
void set_epsilon (
double eps_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void structural_sequence_segmentation_trainer::set_epsilon()"
<< "\n\t eps_ must be greater than 0"
<< "\n\t eps_: " << eps_
<< "\n\t this: " << this
);
trainer.set_epsilon(eps_);
}
double get_epsilon (
) const { return trainer.get_epsilon(); }
unsigned long get_max_iterations (
) const { return trainer.get_max_iterations(); }
void set_max_iterations (
unsigned long max_iter
)
{
trainer.set_max_iterations(max_iter);
}
void set_max_cache_size (
unsigned long max_size
)
{
trainer.set_max_cache_size(max_size);
}
unsigned long get_max_cache_size (
) const
{
return trainer.get_max_cache_size();
}
void be_verbose (
)
{
trainer.be_verbose();
}
void be_quiet (
)
{
trainer.be_quiet();
}
void set_oca (
const oca& item
)
{
trainer.set_oca(item);
}
const oca get_oca (
) const
{
return trainer.get_oca();
}
void set_c (
double C_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(C_ > 0,
"\t void structural_sequence_segmentation_trainer::set_c()"
<< "\n\t C_ must be greater than 0"
<< "\n\t C_: " << C_
<< "\n\t this: " << this
);
trainer.set_c(C_);
}
double get_c (
) const
{
return trainer.get_c();
}
void set_loss_per_missed_segment (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_missed_segment = loss;
if (feature_extractor::use_BIO_model)
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
}
else
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
trainer.set_loss(impl_ss::LAST, loss_per_missed_segment);
trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment);
}
}
double get_loss_per_missed_segment (
) const
{
return loss_per_missed_segment;
}
void set_loss_per_false_alarm (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_alarm = loss;
trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm);
}
double get_loss_per_false_alarm (
) const
{
return loss_per_false_alarm;
}
const sequence_segmenter<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<segmented_sequence_type>& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_sequence_segmentation_problem(x,y) == true,
"\t sequence_segmenter structural_sequence_segmentation_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size()
<< "\n\t is_sequence_segmentation_problem(x,y): " << is_sequence_segmentation_problem(x,y)
<< "\n\t this: " << this
);
std::vector<std::vector<unsigned long> > labels(y.size());
if (feature_extractor::use_BIO_model)
{
// convert y into tagged BIO labels
for (unsigned long i = 0; i < labels.size(); ++i)
{
labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
for (unsigned long j = 0; j < y[i].size(); ++j)
{
const unsigned long begin = y[i][j].first;
const unsigned long end = y[i][j].second;
if (begin != end)
{
labels[i][begin] = impl_ss::BEGIN;
for (unsigned long k = begin+1; k < end; ++k)
labels[i][k] = impl_ss::INSIDE;
}
}
}
}
else
{
// convert y into tagged BILOU labels
for (unsigned long i = 0; i < labels.size(); ++i)
{
labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
for (unsigned long j = 0; j < y[i].size(); ++j)
{
const unsigned long begin = y[i][j].first;
const unsigned long end = y[i][j].second;
if (begin != end)
{
if (begin+1==end)
{
labels[i][begin] = impl_ss::UNIT;
}
else
{
labels[i][begin] = impl_ss::BEGIN;
for (unsigned long k = begin+1; k+1 < end; ++k)
labels[i][k] = impl_ss::INSIDE;
labels[i][end-1] = impl_ss::LAST;
}
}
}
}
}
sequence_labeler<impl_ss::feature_extractor<feature_extractor> > temp;
temp = trainer.train(x, labels);
return sequence_segmenter<feature_extractor>(temp.get_weights(), trainer.get_feature_extractor().fe);
}
private:
structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer;
double loss_per_missed_segment;
double loss_per_false_alarm;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_