Commit 9955202b authored by Federico Rossi's avatar Federico Rossi

Added new models

parent 3f61e11d
......@@ -15,17 +15,6 @@ if(USE_SERIALIZER)
${project_library_target_name} ${REQUIRED_LIBRARIES})
add_dependencies(gtrsb_tests_type gtrsb_test)
add_executable(gtrsb_train_p16_2 train.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_train_p16_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_train_p16_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=16 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_trains_type gtrsb_train_p16_2)
add_executable(gtrsb_test_p16_2 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p16_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p16_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=16 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p16_2)
add_executable(gtrsb_test_p16_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p16_0
......@@ -33,18 +22,7 @@ if(USE_SERIALIZER)
target_compile_definitions(gtrsb_test_p16_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=16 CNN_EXP_BITS=0 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p16_0)
add_executable(gtrsb_train_p14_2 train.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_train_p14_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_train_p14_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=14 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_trains_type gtrsb_train_p14_2)
add_executable(gtrsb_test_p14_2 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p14_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p14_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=14 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p14_2)
add_executable(gtrsb_test_p14_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p14_0
${project_library_target_name} ${REQUIRED_LIBRARIES})
......@@ -52,18 +30,7 @@ if(USE_SERIALIZER)
add_dependencies(gtrsb_tests_type gtrsb_test_p14_0)
add_executable(gtrsb_train_p12_2 train.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_train_p12_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_train_p12_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=12 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_trains_type gtrsb_train_p12_2)
add_executable(gtrsb_test_p12_2 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p12_2
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p12_2 PRIVATE CNN_USE_POSIT CNN_POS_BITS=12 CNN_EXP_BITS=2 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p12_2)
add_executable(gtrsb_test_p12_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p12_0
${project_library_target_name} ${REQUIRED_LIBRARIES})
......@@ -82,81 +49,5 @@ add_executable(gtrsb_test_p12_0 test.cpp ${tiny_dnn_headers})
target_compile_definitions(gtrsb_test_p8_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=8 CNN_EXP_BITS=0 CNN_POS_STORAGE=int16_t CNN_POS_BACKEND=uint16_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p8_0)
# Posit test executables for narrow widths (7/6/5/4 bits) stored in int8_t.
# NOTE(review): gtrsb_test_p7_0 defines CNN_EXP_BITS=1 although its "_0"
# suffix suggests es=0 — confirm whether the target name or the definition
# is the intended one.
add_executable(gtrsb_test_p7_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p7_0 PRIVATE
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p7_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=7 CNN_EXP_BITS=1 CNN_POS_STORAGE=int8_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p7_0)
add_executable(gtrsb_test_p6_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p6_0 PRIVATE
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p6_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=6 CNN_EXP_BITS=0 CNN_POS_STORAGE=int8_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p6_0)
add_executable(gtrsb_test_p5_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p5_0 PRIVATE
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p5_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=5 CNN_EXP_BITS=0 CNN_POS_STORAGE=int8_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p5_0)
add_executable(gtrsb_test_p4_0 test.cpp ${tiny_dnn_headers})
target_link_libraries(gtrsb_test_p4_0 PRIVATE
${project_library_target_name} ${REQUIRED_LIBRARIES})
target_compile_definitions(gtrsb_test_p4_0 PRIVATE CNN_USE_POSIT CNN_POS_BITS=4 CNN_EXP_BITS=0 CNN_POS_STORAGE=int8_t CNN_POS_BACKEND=uint32_t)
add_dependencies(gtrsb_tests_type gtrsb_test_p4_0)
# Tabulated-posit variants: for each enabled width N (cache option TAB<N>),
# build a train and a test executable linked against the matching
# lookup-table library ${TAB_POSIT_LIB<N>}. The four hand-copied blocks
# were identical except for N, so they are collapsed into one loop; the
# previously keyword-less target_link_libraries calls now use PRIVATE.
foreach(tab_bits IN ITEMS 8 10 12 14)
  if(TAB${tab_bits})
    add_executable(gtrsb_train_posittab${tab_bits} train.cpp ${TAB_POSIT_LIB${tab_bits}} ${tiny_dnn_headers})
    target_link_libraries(gtrsb_train_posittab${tab_bits} PRIVATE
      ${project_library_target_name} ${REQUIRED_LIBRARIES})
    target_compile_definitions(gtrsb_train_posittab${tab_bits} PRIVATE CNN_USE_POSIT CNN_TAB_TYPE=posit${tab_bits})
    add_executable(gtrsb_test_posittab${tab_bits} test.cpp ${TAB_POSIT_LIB${tab_bits}} ${tiny_dnn_headers})
    target_link_libraries(gtrsb_test_posittab${tab_bits} PRIVATE
      ${project_library_target_name} ${REQUIRED_LIBRARIES})
    target_compile_definitions(gtrsb_test_posittab${tab_bits} PRIVATE CNN_USE_POSIT CNN_TAB_TYPE=posit${tab_bits})
    add_dependencies(gtrsb_tests_type gtrsb_test_posittab${tab_bits})
    add_dependencies(gtrsb_trains_type gtrsb_train_posittab${tab_bits})
  endif()
endforeach()
endif()
\ No newline at end of file
......@@ -32,7 +32,7 @@ static tiny_dnn::network<tiny_dnn::sequential>* construct_net() {
using fc = tiny_dnn::layers::fc;
using conv = tiny_dnn::layers::conv;
using ave_pool = tiny_dnn::layers::ave_pool;
using tanh = tiny_dnn::activation::sigmoid;
using tanh = tiny_dnn::activation::relu;
using smax = tiny_dnn::activation::softmax;
using tiny_dnn::core::connection_table;
using padding = tiny_dnn::padding;
......@@ -62,7 +62,7 @@ int main(int argc, char **argv) {
argv[1] -> train folder
argv[2] -> test file
argv[3] -> train params
argv[4] -> learning rate
*/
//srand(time(NULL));
unsigned int seed = rand();
......@@ -80,10 +80,10 @@ int main(int argc, char **argv) {
std::cout << "Num test images: " << vimg.size() << std::endl;
tiny_dnn::adagrad optimizer;
optimizer.alpha*=2;
NNTrainParams tparams(argv[3],&optimizer,tlab,vlab,timg,vimg);
optimizer.alpha*=atof(argv[4]);
NNTrainParams tparams(argv[3],&optimizer,vlab,vlab,vimg,vimg);
std::cout << tparams;
NNModel model("GTRSB-TANH",*construct_net());
model.train(tparams);
NNModel model("GTRSB-RELU",*construct_net());
model.train<tiny_dnn::mse>(tparams);
model.save();
}
......@@ -26,5 +26,49 @@ namespace tiny_dnn {
return residualBlocks;
}
std::vector<layer*> make_block3(size_t inputWidth,size_t inputHeight,size_t inputChannels,bool sub,bool up=false) {
using sc = tiny_dnn::shortcut_layer;
using conv = tiny_dnn::convolutional_layer;
using relu = tiny_dnn::relu_layer;
std::vector<layer*> residualBlocks;
size_t scale_i = (up)?1:4;
residualBlocks.push_back(new sc(tiny_dnn::ShortcutSide::BEGIN,inputWidth,inputHeight,scale_i*inputChannels));
size_t stride = sub?2:1;
size_t outputChannels = inputChannels*((sub)?2:1);
residualBlocks.push_back( // outputchannels@1x1
new conv(inputWidth,inputHeight,1,scale_i*inputChannels,outputChannels,padding::same,true, stride, stride, stride, stride)
);
residualBlocks.push_back(new relu());
inputWidth/=(sub)?2:1;
inputHeight/=(sub)?2:1;
residualBlocks.push_back( // outputchannels@3x3
new conv(inputWidth,inputHeight,3,outputChannels,outputChannels,padding::same,true, 1, 1, 1, 1)
);
residualBlocks.push_back( // 4*outputchannels@1x1
new conv(inputWidth,inputHeight,1,outputChannels,4*outputChannels,padding::same,true, 1, 1, 1, 1)
);
residualBlocks.push_back(new sc(tiny_dnn::ShortcutSide::END,inputWidth,inputHeight,4*outputChannels));
return residualBlocks;
}
// Appends a two-conv residual branch to `nn` and merges it with the network's
// previous output via an elementwise add (two-input join over w*h*c values).
void experimentalBlock(tiny_dnn::network<tiny_dnn::sequential>* nn,size_t w,size_t h,size_t c) {
  using add  = tiny_dnn::elementwise_add_layer;
  using conv = tiny_dnn::convolutional_layer;
  using relu = tiny_dnn::relu_layer;

  // The layer currently ending the network becomes the identity branch.
  layer* tail = (*nn)[nn->layer_size() - 1];

  // Residual branch: conv -> relu -> conv, all shape-preserving (same pad).
  *nn << conv(w, h, 3, c, c, padding::same, true, 1, 1, 1, 1);
  *nn << relu();
  *nn << conv(w, h, 3, c, c, padding::same, true, 1, 1, 1, 1);

  // Wire both branches into the add node: the branch output through the
  // network, the saved tail directly.
  // NOTE(review): `joint` is heap-allocated with no matching delete here;
  // ownership is presumably taken by the graph — confirm.
  layer* joint = new add(2, w * h * c);
  *nn << *joint;
  *tail << *joint;
}
}
}
......@@ -19,9 +19,10 @@ public:
this->network = network;
};
template <class LossFunction>
void train(NNTrainParams& params) {
NNTrainSession s(&network,&params);
s.start();
s.start<LossFunction>();
}
void save(std::string path) {
......
......@@ -84,8 +84,9 @@ public:
}
template <class LossFunction>
void start() {
_model->train<tiny_dnn::mse>(*(_params->optimizer), _params->timages, _params->tlabels, _params->minibatch_size,
_model->train<LossFunction>(*(_params->optimizer), _params->timages, _params->tlabels, _params->minibatch_size,
_params->epochs,
std::bind(&NNTrainSession::_on_enumerate_minibatch,this),
std::bind(&NNTrainSession::_on_enumerate_epoch,this),
......
/*
Copyright (c) 2013, Taiga Nomi and the respective contributors
All rights reserved.
Use of this source code is governed by a BSD-style license that can be found
in the LICENSE file.
*/
#pragma once
#include <string>
// Based on:
// "Deep Residual Learning for Image Recognition" (ResNet-152), https://arxiv.org/abs/1512.03385
// ResNet-152 built from bottleneck blocks: stages of 3 / 8 / 36 / 3 blocks
// at 64 / 128 / 256 / 512 base channels, 7x7 stride-2 stem, global average
// pool and a final fully-connected classifier.
template <size_t width,size_t height,size_t n_classes>
class resnet152 : public tiny_dnn::network<tiny_dnn::sequential> {
 public:
  explicit resnet152(const std::string &name = "")
    : tiny_dnn::network<tiny_dnn::sequential>(name) {
    using relu     = tiny_dnn::activation::relu;
    using conv     = tiny_dnn::layers::conv;
    using fc       = tiny_dnn::layers::fc;
    using max_pool = tiny_dnn::layers::max_pool;
    using ave_pool = tiny_dnn::global_average_pooling_layer;
    using softmax  = tiny_dnn::activation::softmax;

    size_t img_width = width, img_height = height;
    // Stem: 7x7/2 conv then 3x3/2 max-pool (each assignment inside a call
    // updates the running resolution before the layer is constructed).
    *this << conv(img_width, img_height, 7, 7, 3, 64, padding::same, true, 2, 2, 2, 2) << relu();
    *this << max_pool(img_width = img_width / 2, img_height = img_height / 2, 64, 3, 3, 2, 2, false, padding::same);
    // Stage 1: 3 blocks at 64 channels.
    // NOTE(review): the resolution is halved again here, on top of the two
    // stem halvings — standard ResNet keeps the post-pool size for stage 1;
    // confirm this extra /2 is intended.
    *this << tiny_dnn::residual::make_block3(img_width = img_width / 2, img_height = img_height / 2, 64, false, true);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 64, false);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 64, true);
    // Stage 2: 8 blocks at 128 channels.
    *this << tiny_dnn::residual::make_block3(img_width = img_width / 2, img_height = img_height / 2, 128, false);
    for (int i = 0; i < 6; ++i)
      *this << tiny_dnn::residual::make_block3(img_width, img_height, 128, false);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 128, true);
    // Stage 3: 36 blocks at 256 channels.
    *this << tiny_dnn::residual::make_block3(img_width = img_width / 2, img_height = img_height / 2, 256, false);
    for (int i = 0; i < 34; ++i)
      *this << tiny_dnn::residual::make_block3(img_width, img_height, 256, false);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 256, true);
    // Stage 4: 3 blocks at 512 channels.
    *this << tiny_dnn::residual::make_block3(img_width = img_width / 2, img_height = img_height / 2, 512, false);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 512, false);
    *this << tiny_dnn::residual::make_block3(img_width, img_height, 512, false);
    // Head: global average pool over the 4x-expanded channels, then classify.
    // FIX: the classifier previously hard-coded 1000 outputs, ignoring the
    // n_classes template parameter.
    *this << ave_pool(img_width, img_height, 4 * 512) << fc(4 * 512, n_classes) << softmax();
  }
};
/*
Copyright (c) 2013, Taiga Nomi and the respective contributors
All rights reserved.
Use of this source code is governed by a BSD-style license that can be found
in the LICENSE file.
*/
#pragma once
#include <string>
// Based on:
// "Very Deep Convolutional Networks for Large-Scale Image Recognition" (VGG-19), https://arxiv.org/abs/1409.1556
// VGG-19 for 224x224x3 inputs: sixteen 3x3 same-padded convolutions in five
// blocks (64/128/256/256/512 widths, each block followed by a 2x2 max-pool),
// then three fully-connected layers ending in a softmax over n_classes.
template <size_t n_classes>
class vgg19 : public tiny_dnn::network<tiny_dnn::sequential> {
 public:
  explicit vgg19(const std::string &name = "")
    : tiny_dnn::network<tiny_dnn::sequential>(name) {
    using relu     = tiny_dnn::activation::relu;
    using conv     = tiny_dnn::layers::conv;
    using fc       = tiny_dnn::layers::fc;
    using max_pool = tiny_dnn::layers::max_pool;
    using softmax  = tiny_dnn::activation::softmax;

    // Block 1: 224x224, 3 -> 64 channels.
    *this << conv(224, 224, 3, 3, 3, 64, padding::same) << relu()
          << conv(224, 224, 3, 3, 64, 64, padding::same) << relu()
          << max_pool(224, 224, 64, 2);
    // Block 2: 112x112, 64 -> 128 channels.
    *this << conv(112, 112, 3, 3, 64, 128, padding::same) << relu()
          << conv(112, 112, 3, 3, 128, 128, padding::same) << relu()
          << max_pool(112, 112, 128, 2);
    // Block 3: 56x56, 128 -> 256 channels, four convs.
    *this << conv(56, 56, 3, 3, 128, 256, padding::same) << relu()
          << conv(56, 56, 3, 3, 256, 256, padding::same) << relu()
          << conv(56, 56, 3, 3, 256, 256, padding::same) << relu()
          << conv(56, 56, 3, 3, 256, 256, padding::same) << relu()
          << max_pool(56, 56, 256, 2);
    // Block 4: 28x28, 256 -> 512 channels, four convs.
    *this << conv(28, 28, 3, 3, 256, 512, padding::same) << relu()
          << conv(28, 28, 3, 3, 512, 512, padding::same) << relu()
          << conv(28, 28, 3, 3, 512, 512, padding::same) << relu()
          << conv(28, 28, 3, 3, 512, 512, padding::same) << relu()
          << max_pool(28, 28, 512, 2);
    // Block 5: 14x14, 512 channels, four convs.
    *this << conv(14, 14, 3, 3, 512, 512, padding::same) << relu()
          << conv(14, 14, 3, 3, 512, 512, padding::same) << relu()
          << conv(14, 14, 3, 3, 512, 512, padding::same) << relu()
          << conv(14, 14, 3, 3, 512, 512, padding::same) << relu()
          << max_pool(14, 14, 512, 2);
    // Classifier head on the 7x7x512 feature map.
    *this << fc(7 * 7 * 512, 4096) << relu()
          << fc(4096, 4096) << relu()
          << fc(4096, n_classes) << softmax();
  }
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment