// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/statistics.h>
#include <vector>
#include <sstream>
namespace  
{
    using namespace test;
    using namespace dlib;
    using namespace std;
    // Logger for this test; perform_test() writes a progress line to it before
    // each run_test<>() instantiation is executed.
    dlib::logger dlog("test.one_vs_one_trainer");
    class test_one_vs_one_trainer : public tester
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object represents a unit test.  When it is constructed
                it adds itself into the testing framework.
        !*/
    public:
        test_one_vs_one_trainer (
        ) :
            tester (
                "test_one_vs_one_trainer",       // the command line argument name for this test
                "Run tests on the one_vs_one_trainer stuff.", // the command line argument description
                0                     // the number of command line arguments for this test
            )
        {
        }
        template <typename sample_type, typename label_type>
        void generate_data (
            std::vector<sample_type>& samples,
            std::vector<label_type>& labels
        )
        {
            const long num = 50;
            sample_type m;
            dlib::rand rnd;
            // make some samples near the origin
            double radius = 0.5;
            for (long i = 0; i < num+10; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));
                // add this sample to our set of samples we will run k-means 
                samples.push_back(m);
                labels.push_back(1);
            }
            // make some samples in a circle around the origin but far away
            radius = 10.0;
            for (long i = 0; i < num+20; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));
                // add this sample to our set of samples we will run k-means 
                samples.push_back(m);
                labels.push_back(2);
            }
            // make some samples in a circle around the point (25,25) 
            radius = 4.0;
            for (long i = 0; i < num+30; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));
                // translate this point away from the origin
                m(0) += 25;
                m(1) += 25;
                // add this sample to our set of samples we will run k-means 
                samples.push_back(m);
                labels.push_back(3);
            }
        }
        template <typename label_type, typename scalar_type>
        void run_test (
        )
        {
            print_spinner();
            typedef matrix<scalar_type,2,1> sample_type;
            std::vector<sample_type> samples, norm_samples;
            std::vector<label_type> labels;
            // First, get our labeled set of training data
            generate_data(samples, labels);
            typedef one_vs_one_trainer<any_trainer<sample_type,scalar_type>,label_type > ovo_trainer;
            ovo_trainer trainer;
            typedef histogram_intersection_kernel<sample_type> hist_kernel;
            typedef radial_basis_kernel<sample_type> rbf_kernel;
            // make the binary trainers and set some parameters
            krr_trainer<rbf_kernel> rbf_trainer;
            svm_nu_trainer<hist_kernel> hist_trainer;
            rbf_trainer.set_kernel(rbf_kernel(0.1));
            trainer.set_trainer(rbf_trainer);
            trainer.set_trainer(hist_trainer, 1, 2);
            randomize_samples(samples, labels);
            matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
            print_spinner();
            matrix<scalar_type> ans(3,3);
            ans = 60,  0,  0, 
                  0, 70,  0, 
                  0,  0, 80;
            DLIB_TEST_MSG(ans == res, "res: \n" << res);
            // test using a normalized_function with a one_vs_one_decision_function 
            {
                trainer.set_trainer(hist_trainer, 1, 2);
                vector_normalizer<sample_type> normalizer;
                normalizer.train(samples);
                for (unsigned long i = 0; i < samples.size(); ++i)
                    norm_samples.push_back(normalizer(samples[i]));
                normalized_function<one_vs_one_decision_function<ovo_trainer> > ndf;
                ndf.function = trainer.train(norm_samples, labels);
                ndf.normalizer = normalizer;
                DLIB_TEST(ndf(samples[0])  == labels[0]);
                DLIB_TEST(ndf(samples[40])  == labels[40]);
                DLIB_TEST(ndf(samples[90])  == labels[90]);
                DLIB_TEST(ndf(samples[120])  == labels[120]);
                trainer.set_trainer(hist_trainer, 1, 2);
                print_spinner();
            }
            one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
            DLIB_TEST(df.number_of_classes() == 3);
            DLIB_TEST(df(samples[0])  == labels[0]);
            DLIB_TEST(df(samples[90])  == labels[90]);
            one_vs_one_decision_function<ovo_trainer, 
                decision_function<hist_kernel>,  // This is the output of the hist_trainer
                decision_function<rbf_kernel>    // This is the output of the rbf_trainer
            > df2, df3;
            df2 = df;
            ofstream fout("df.dat", ios::binary);
            serialize(df2, fout);
            fout.close();
            // load the function back in from disk and store it in df3.  
            ifstream fin("df.dat", ios::binary);
            deserialize(df3, fin);
            DLIB_TEST(df3(samples[0])  == labels[0]);
            DLIB_TEST(df3(samples[90])  == labels[90]);
            res = test_multiclass_decision_function(df3, samples, labels);
            DLIB_TEST(res == ans);
        }
        void perform_test (
        )
        {
            dlog << LINFO << "run_test<double,double>()";
            run_test<double,double>();
            dlog << LINFO << "run_test<int,double>()";
            run_test<int,double>();
            dlog << LINFO << "run_test<double,float>()";
            run_test<double,float>();
            dlog << LINFO << "run_test<int,float>()";
            run_test<int,float>();
        }
    };
    // File-scope instance: constructing it adds this test to the testing
    // framework, so it must exist even though nothing references it by name.
    test_one_vs_one_trainer a;
}