As the global population rises, so does the demand for resources, above all the food needed to sustain life; feeding everyone is one of the defining challenges of the coming decades. The output of today's agricultural land and industry is nowhere near sufficient for such numbers. To close the gap, we have to make existing farming practices more efficient with the help of technology, and that is exactly what we set out to do.
Farmer suicide rates are high in many places, India among them, driven in part by harvests destroyed by plant disease. Many of these losses, and lives, could be saved, and the contribution to the global food supply preserved, if there were a practical, well-implemented way to keep a check on every plant in the field. This matters because a single infected plant can still spoil the whole field.
Ichigo is a robotic rover that can travel over almost any field terrain and report the health status of the entire field, flagging plants that have contracted a disease. The rover is driven from outside the field with a controller, which also lets you photograph any leaf you choose; the images are later fed to a convolutional neural network that predicts whether the plant is diseased. This could considerably improve the lives of farmers and raise the efficiency of farming. The rover also reports the location of each diseased plant using GPS data, avoids crushing crops through obstacle avoidance implemented with an ultrasonic sensor, and carries a robotic arm for cutting off weeds.
Ichigo in Action

Overview
This project has three major parts: the Sony PlayStation controller, which sends the control signals; the computer, which receives those signals, relays them to the rover, and performs the image processing; and the rover itself, which carries out the commands sent by the computer. The technical details of each part are explained in the sections below.
Setting up the robot
- Connect the motors of the rover and the robotic arm to the motor drivers and configure the corresponding pins in the Arduino IDE. There is ample material on the web covering how to wire motors to motor drivers.
- Connect the motor drivers to the Sony board.
- Connect the Sony camera module and the ultrasonic sensor to the Sony board and configure them in the Arduino IDE.
- Connect the Bluetooth module to the board for communication over Bluetooth.
- Connect an appropriate power source to the motor drivers of the rover and the robotic arm.
Setting up the Sony PlayStation Controller
First, the PlayStation controller is paired with the computer over Bluetooth, and we use DS4 tools to get the controller working with the Windows operating system. The tool also lets us map the controller's buttons and stick movements to specific keys on the keyboard; this is how key presses are sent to the robot.
Setting up the Computer
The computer is responsible for sending the signals to the Bluetooth module on the robot using Processing; the code is given in the code section below. The computer also performs the image processing using a convolutional neural network, built with Sony's nnabla library. The implemented CNN is a ResNet model trained on the PlantVillage dataset, a collection of diseased and healthy leaf samples for specific plant species.
The dataset used can be found here: https://github.com/spMohanty/PlantVillage-Dataset
The nnabla libraries are very well documented and the examples can be found in this link: https://github.com/sony/nnabla/blob/master/tutorial/by_examples.ipynb
A picture is taken by the robot when a particular button on the controller is pressed. The picture is stored on the robot's SD card and, once the rover has finished its run through the field, uploaded to the computer, which then processes the images and identifies the diseased plants.
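As a rough sketch of that last step, the snippet below shows how one saved picture could be classified with nnabla. It is a minimal illustration, not the project's exact script: the parameter file name params.h5, the 256x256 input size, and the helper name predict_leaf are our assumptions, and prediction_func stands for a network-building function such as the ResNet variants defined in the training code further below.

import numpy as np
import nnabla as nn
from PIL import Image

def predict_leaf(image_path, prediction_func, param_file='params.h5'):
    # Load the weights saved during training with nn.save_parameters().
    nn.load_parameters(param_file)
    # Read one leaf photo and shape it as an NCHW batch of one.
    img = Image.open(image_path).convert('RGB').resize((256, 256))
    x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)[np.newaxis]
    image = nn.Variable(x.shape)
    image.d = x
    # Build the graph in test mode and run a forward pass.
    pred = prediction_func(image / 255, test=True)
    pred.forward(clear_buffer=True)
    # Index of the most probable class, e.g. 0 = healthy, 1 = diseased.
    return int(pred.d.argmax(axis=1)[0])

Calling predict_leaf('PICT000.JPG', my_resnet_prediction) then classifies one of the images copied from the rover's SD card, provided my_resnet_prediction (a placeholder name) builds the same network and input shape used in training.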
Processing code
import processing.serial.*;
Serial myPort; // Create object from Serial class
void setup()
{
size(200,200); //make our canvas 200 x 200 pixels big
println(Serial.list());
String portName = Serial.list()[0]; //change the 0 to a 1 or 2 etc. to match your port
myPort = new Serial(this, portName, 9600);
}
void draw()
{
  if (keyPressed == true)
  {
    // Forward recognised command keys to the rover over serial:
    // w/a/s/d drive the rover, i/k and j/l run the two arm motors,
    // and c asks the board to take a still picture.
    if (key == 'w' || key == 'a' || key == 's' || key == 'd' ||
        key == 'i' || key == 'k' || key == 'j' || key == 'l' ||
        key == 'c')
    {
      myPort.write(key);
      println(key);
      delay(100);
    }
  }
  else
  {
    println("__INVALID/NO INPUT__");
  }
}
Sony Board Code - Camera, GPS, and Ultrasonic Sensor
#include <SDHCI.h>
#include <stdio.h> /* for sprintf */
#include <Camera.h>
#include <GNSS.h>
#define STRING_BUFFER_SIZE 128 /**< %Buffer size */
#define RESTART_CYCLE (60 * 5) /**< positioning test term */
static SpGnss Gnss; /**< SpGnss object */
#define BAUDRATE (9600)
/**
* @enum ParamSat
* @brief Satellite system
*/
enum ParamSat {
eSatGps, /**< GPS World wide coverage */
eSatGlonass, /**< GLONASS World wide coverage */
eSatGpsSbas, /**< GPS+SBAS North America */
eSatGpsGlonass, /**< GPS+Glonass World wide coverage */
eSatGpsQz1c, /**< GPS+QZSS_L1CA East Asia & Oceania */
eSatGpsGlonassQz1c, /**< GPS+Glonass+QZSS_L1CA East Asia & Oceania */
eSatGpsQz1cQz1S, /**< GPS+QZSS_L1CA+QZSS_L1S Japan */
};
/* Set this parameter depending on your current region. */
static enum ParamSat satType = eSatGps;
/**
* @brief Turn on / off the LED0 for CPU active notification.
*/
static void Led_isActive(void)
{
static int state = 1;
if (state == 1)
{
ledOn(PIN_LED0);
state = 0;
}
else
{
ledOff(PIN_LED0);
state = 1;
}
}
/**
* @brief Turn on / off the LED1 for positioning state notification.
*
* @param [in] state Positioning state
*/
static void Led_isPosfix(bool state)
{
if (state)
{
ledOn(PIN_LED1);
}
else
{
ledOff(PIN_LED1);
}
}
/**
* @brief Turn on / off the LED3 for error notification.
*
* @param [in] state Error state
*/
static void Led_isError(bool state)
{
if (state)
{
ledOn(PIN_LED3);
}
else
{
ledOff(PIN_LED3);
}
}
/**
* @brief Activate GNSS device and start positioning.
*/
SDClass theSD;
int take_picture_count = 0;
/**
* Callback from Camera library when video frame is captured.
*/
void CamCB(CamImage img)
{
/* Check the img instance is available or not. */
if (img.isAvailable())
{
/* If you want RGB565 data, convert image data format to RGB565 */
img.convertPixFormat(CAM_IMAGE_PIX_FMT_RGB565);
/* You can use the image data directly via getImgSize() and getImgBuff(),
* e.g. for displaying the image on a screen. */
Serial.print("Image data size = ");
Serial.print(img.getImgSize(), DEC);
Serial.print(" , ");
Serial.print("buff addr = ");
Serial.print((unsigned long)img.getImgBuff(), HEX);
Serial.println("");
}
else
{
Serial.print("Failed to get video stream image\n");
}
}
/**
* @brief Initialize camera
*/
//Motor A left motor
const int motorPin1 = 9; // Pin 14 of L293
const int motorPin2 = 10; // Pin 10 of L293
//Motor B right motor
const int motorPin3 = 8; // Pin 7 of L293
const int motorPin4 = 7; // Pin 2 of L293
//Arm Motor 1
const int motorPin5 = 5;
const int motorPin6 = 6;
//Arm Motor 2
const int motorPin7 = 3;
const int motorPin8 = 4;
char val;
// defines pins numbers for ultrasonic sensor
const int trigPin = 12;
const int echoPin = 13;
// defines variables for ultrasonic sensor
long duration;
int distance;
void setup() {
/* put your setup code here, to run once: */
int error_flag = 0;
/* Wait HW initialization done. */
sleep(3);
/* Turn on all LED:Setup start. */
ledOn(PIN_LED0);
ledOn(PIN_LED1);
ledOn(PIN_LED2);
ledOn(PIN_LED3);
/* Set Debug mode to Info */
Gnss.setDebugMode(PrintInfo);
int result;
/* Activate GNSS device */
result = Gnss.begin();
if (result != 0)
{
Serial.println("Gnss begin error!!");
error_flag = 1;
}
else
{
/* Setup GNSS
* It is possible to set up to two GNSS satellite systems.
* Depending on your location, you can improve accuracy by selecting a
* different GNSS system than GPS.
* See: https://developer.sony.com/develop/spresense/developer-tools/get-started-using-nuttx/nuttx-developer-guide#_gnss
* for detailed information.
*/
switch (satType)
{
case eSatGps:
Gnss.select(GPS);
break;
case eSatGpsSbas:
Gnss.select(GPS);
Gnss.select(SBAS);
break;
case eSatGlonass:
Gnss.select(GLONASS);
break;
case eSatGpsGlonass:
Gnss.select(GPS);
Gnss.select(GLONASS);
break;
case eSatGpsQz1c:
Gnss.select(GPS);
Gnss.select(QZ_L1CA);
break;
case eSatGpsQz1cQz1S:
Gnss.select(GPS);
Gnss.select(QZ_L1CA);
Gnss.select(QZ_L1S);
break;
case eSatGpsGlonassQz1c:
default:
Gnss.select(GPS);
Gnss.select(GLONASS);
Gnss.select(QZ_L1CA);
break;
}
/* Start positioning */
result = Gnss.start(COLD_START);
if (result != 0)
{
Serial.println("Gnss start error!!");
error_flag = 1;
}
else
{
Serial.println("Gnss setup OK");
}
}
/* Turn off all LED:Setup done. */
ledOff(PIN_LED0);
ledOff(PIN_LED1);
ledOff(PIN_LED2);
ledOff(PIN_LED3);
/* Set error LED. */
if (error_flag == 1)
{
Led_isError(true);
exit(0);
}
/* Open serial communications and wait for port to open */
Serial.begin(BAUDRATE);
while (!Serial)
{
; /* wait for serial port to connect. Needed for native USB port only */
}
/* begin() without parameters means that
* number of buffers = 1, 30FPS, QVGA, YUV 4:2:2 format */
Serial.println("Prepare camera");
theCamera.begin();
/* Start video stream.
* If received video stream data from camera device,
* camera library call CamCB.
*/
Serial.println("Start streaming");
theCamera.startStreaming(true, CamCB);
/* Auto white balance configuration */
Serial.println("Set Auto white balance parameter");
theCamera.setAutoWhiteBalanceMode(CAM_WHITE_BALANCE_DAYLIGHT);
/* Set parameters about still picture.
* In the following case, QUADVGA and JPEG.
*/
Serial.println("Start streaming");
theCamera.setStillPictureImageFormat(
CAM_IMGSIZE_QUADVGA_H,
CAM_IMGSIZE_QUADVGA_V,
CAM_IMAGE_PIX_FMT_JPG);
pinMode(motorPin1, OUTPUT);
pinMode(motorPin2, OUTPUT);
pinMode(motorPin3, OUTPUT);
pinMode(motorPin4, OUTPUT);
pinMode(motorPin5, OUTPUT);
pinMode(motorPin6, OUTPUT);
pinMode(motorPin7, OUTPUT);
pinMode(motorPin8, OUTPUT);
pinMode(trigPin, OUTPUT); // Sets the trigPin as an Output
pinMode(echoPin, INPUT); // Sets the echoPin as an Input
// Serial.begin(9600);
}
/**
* @brief %Print position information.
*/
static void print_pos(SpNavData *pNavData)
{
char StringBuffer[STRING_BUFFER_SIZE];
/* print time */
snprintf(StringBuffer, STRING_BUFFER_SIZE, "%04d/%02d/%02d ", pNavData->time.year, pNavData->time.month, pNavData->time.day);
Serial.print(StringBuffer);
snprintf(StringBuffer, STRING_BUFFER_SIZE, "%02d:%02d:%02d.%06d, ", pNavData->time.hour, pNavData->time.minute, pNavData->time.sec, pNavData->time.usec);
Serial.print(StringBuffer);
/* print satellites count */
snprintf(StringBuffer, STRING_BUFFER_SIZE, "numSat:%2d, ", pNavData->numSatellites);
Serial.print(StringBuffer);
/* print position data */
if (pNavData->posFixMode == FixInvalid)
{
Serial.print("No-Fix, ");
}
else
{
Serial.print("Fix, ");
}
if (pNavData->posDataExist == 0)
{
Serial.print("No Position");
}
else
{
Serial.print("Lat=");
Serial.print(pNavData->latitude, 6);
Serial.print(", Lon=");
Serial.print(pNavData->longitude, 6);
}
Serial.println("");
}
/**
* @brief %Print satellite condition.
*/
static void print_condition(SpNavData *pNavData)
{
char StringBuffer[STRING_BUFFER_SIZE];
unsigned long cnt;
/* Print satellite count. */
snprintf(StringBuffer, STRING_BUFFER_SIZE, "numSatellites:%2d\n", pNavData->numSatellites);
Serial.print(StringBuffer);
for (cnt = 0; cnt < pNavData->numSatellites; cnt++)
{
const char *pType = "---";
SpSatelliteType sattype = pNavData->getSatelliteType(cnt);
/* Get satellite type. */
/* Keep it to three letters. */
switch (sattype)
{
case GPS:
pType = "GPS";
break;
case GLONASS:
pType = "GLN";
break;
case QZ_L1CA:
pType = "QCA";
break;
case SBAS:
pType = "SBA";
break;
case QZ_L1S:
pType = "Q1S";
break;
default:
pType = "UKN";
break;
}
/* Get print conditions. */
unsigned long Id = pNavData->getSatelliteId(cnt);
unsigned long Elv = pNavData->getSatelliteElevation(cnt);
unsigned long Azm = pNavData->getSatelliteAzimuth(cnt);
float sigLevel = pNavData->getSatelliteSignalLevel(cnt);
/* Print satellite condition. */
snprintf(StringBuffer, STRING_BUFFER_SIZE, "[%2d] Type:%s, Id:%2d, Elv:%2d, Azm:%3d, CN0:", cnt, pType, Id, Elv, Azm );
Serial.print(StringBuffer);
Serial.println(sigLevel, 6);
}
}
/**
* @brief %Print position information and satellite condition.
*
* @details When the loop count reaches the RESTART_CYCLE value, GNSS device is
* restarted.
*/
void loop() {
static int LoopCount = 0;
static int LastPrintMin = 0;
/* Blink LED. */
Led_isActive();
/* Check update. */
if (Gnss.waitUpdate(-1))
{
/* Get NaviData. */
SpNavData NavData;
Gnss.getNavData(&NavData);
/* Set posfix LED. */
bool LedSet = (NavData.posDataExist && (NavData.posFixMode != FixInvalid));
Led_isPosfix(LedSet);
/* Print satellite information every minute. */
if (NavData.time.minute != LastPrintMin)
{
print_condition(&NavData);
LastPrintMin = NavData.time.minute;
}
/* Print position information. */
print_pos(&NavData);
}
else
{
/* Not update. */
Serial.println("data not update");
}
/* Check loop count. */
LoopCount++;
if (LoopCount >= RESTART_CYCLE)
{
int error_flag = 0;
/* Turn off LED0 */
ledOff(PIN_LED0);
/* Set posfix LED. */
Led_isPosfix(false);
/* Restart GNSS. */
if (Gnss.stop() != 0)
{
Serial.println("Gnss stop error!!");
error_flag = 1;
}
else if (Gnss.end() != 0)
{
Serial.println("Gnss end error!!");
error_flag = 1;
}
else
{
Serial.println("Gnss stop OK.");
}
if (Gnss.begin() != 0)
{
Serial.println("Gnss begin error!!");
error_flag = 1;
}
else if (Gnss.start(HOT_START) != 0)
{
Serial.println("Gnss start error!!");
error_flag = 1;
}
else
{
Serial.println("Gnss restart OK.");
}
LoopCount = 0;
/* Set error LED. */
if (error_flag == 1)
{
Led_isError(true);
exit(0);
}
}
digitalWrite(trigPin, LOW);
delayMicroseconds(2);
// Sets the trigPin on HIGH state for 10 micro seconds
digitalWrite(trigPin, HIGH);
delayMicroseconds(10);
digitalWrite(trigPin, LOW);
// Reads the echoPin, returns the sound wave travel time in microseconds
duration = pulseIn(echoPin, HIGH);
// Calculate the distance: sound travels ~0.034 cm/us and the echo time
// covers the round trip, so halve the product
distance = duration * 0.034 / 2;
delay(500);
// Prints the distance on the Serial Monitor
Serial.print("Distance: ");
Serial.println(distance);
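/* NOTE: a minimal obstacle-avoidance sketch, not part of the vendor sample.
 * The 20 cm threshold is an assumption to tune for your field; stopping the
 * drive motors when something is closer keeps the rover off the crops. */
if (distance > 0 && distance < 20)
{
  digitalWrite(motorPin1, LOW);
  digitalWrite(motorPin2, LOW);
  digitalWrite(motorPin3, LOW);
  digitalWrite(motorPin4, LOW);
  Serial.println("Obstacle detected - drive motors stopped");
}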
if (Serial.available())
{
val = Serial.read();
delay(100);
if (val == 's')
{
//Backward code
// Set Motor A backward
digitalWrite(motorPin1, HIGH);
digitalWrite(motorPin2, LOW);
// Set Motor B backward
digitalWrite(motorPin3, HIGH);
digitalWrite(motorPin4, LOW);
Serial.println("Backward");
}
if (val == 'w')
{
//forward code
// Set Motor A forward
digitalWrite(motorPin1, LOW);
digitalWrite(motorPin2, HIGH);
// Set Motor B forward
digitalWrite(motorPin3, LOW);
digitalWrite(motorPin4, HIGH);
// delay(2000);
}
if (val == 'a')
{
//left code
// Set Motor A backward
digitalWrite(motorPin1, HIGH);
digitalWrite(motorPin2, LOW);
// Set Motor B forward
digitalWrite(motorPin3, LOW);
digitalWrite(motorPin4, HIGH);
//delay(2000);
}
if (val == 'd')
{
//right code
// Set Motor A forward
digitalWrite(motorPin1, LOW);
digitalWrite(motorPin2, HIGH);
// Set Motor B backward
digitalWrite(motorPin3, HIGH);
digitalWrite(motorPin4, LOW);
//delay(2000);
}
if (val == 'i')
{
//Arm motor 1 operation 1
digitalWrite(motorPin5, LOW);
digitalWrite(motorPin6, HIGH);
}
if (val == 'k')
{
//Arm motor 1 operation 2
digitalWrite(motorPin5, HIGH);
digitalWrite(motorPin6, LOW);
}
if (val == 'j')
{
//Arm motor 2 operation 1
digitalWrite(motorPin7, HIGH);
digitalWrite(motorPin8, LOW);
}
if (val == 'l')
{
//Arm motor 2 operation 2
digitalWrite(motorPin7, LOW);
digitalWrite(motorPin8, HIGH);
}
if (val == 'c')
{
sleep(1); /* wait for one second to take still picture. */
/* You can change the format of still picture at here also, if you want. */
/* theCamera.setStillPictureImageFormat(
* CAM_IMGSIZE_HD_H,
* CAM_IMGSIZE_HD_V,
* CAM_IMAGE_PIX_FMT_JPG);
*/
/* Up to 100 pictures can be taken, one per press of 'c'. */
if (take_picture_count < 100)
{
/* Take a still picture.
* Unlike the video stream (startStreaming), this API waits to receive
* image data from the camera device.
*/
Serial.println("call takePicture()");
CamImage img = theCamera.takePicture();
/* Check availability of the img instance. */
/* If any error occurred, the img is not available. */
if (img.isAvailable())
{
/* Create file name */
char filename[16] = {0};
sprintf(filename, "PICT%03d.JPG", take_picture_count);
Serial.print("Save taken picture as ");
Serial.print(filename);
Serial.println("");
/* Save to SD card under that filename */
File myFile = theSD.open(filename, FILE_WRITE);
myFile.write(img.getImgBuff(), img.getImgSize());
myFile.close();
}
take_picture_count++;
}
}
}
else
{
Serial.println("Serial not available");
digitalWrite(motorPin1, LOW);
digitalWrite(motorPin2, LOW);
digitalWrite(motorPin3, LOW);
digitalWrite(motorPin4, LOW);
digitalWrite(motorPin5, LOW);
digitalWrite(motorPin6, LOW);
digitalWrite(motorPin7, LOW);
digitalWrite(motorPin8, LOW);
}
}
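CNN Training Code - nnabla

The script below is Sony's nnabla binary-network example, which serves here as the training skeleton. As written it trains on MNIST; adapting it to the PlantVillage leaf images comes down to swapping the data iterator and the input shape (a sketch of such an iterator follows the script).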
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import nnabla.utils.save as save
from args import get_args
from mnist_data import data_iterator_mnist
import os
def categorical_error(pred, label):
"""
Compute categorical error given score vectors and labels as
numpy.ndarray.
"""
pred_label = pred.argmax(1)
return (pred_label != label.flat).mean()
# Binary Connect Model
def mnist_binary_connect_lenet_prediction(image, test=False):
"""
Construct LeNet for MNIST (BinaryConnect version).
"""
with nn.parameter_scope("conv1"):
c1 = PF.binary_connect_convolution(image, 16, (5, 5))
c1 = PF.batch_normalization(c1, batch_stat=not test)
c1 = F.elu(F.average_pooling(c1, (2, 2)))
with nn.parameter_scope("conv2"):
c2 = PF.binary_connect_convolution(c1, 16, (5, 5))
c2 = PF.batch_normalization(c2, batch_stat=not test)
c2 = F.elu(F.average_pooling(c2, (2, 2)))
with nn.parameter_scope("fc3"):
c3 = PF.binary_connect_affine(c2, 50)
c3 = PF.batch_normalization(c3, batch_stat=not test)
c3 = F.elu(c3)
with nn.parameter_scope("fc4"):
c4 = PF.binary_connect_affine(c3, 10)
c4 = PF.batch_normalization(c4, batch_stat=not test)
return c4
def mnist_binary_connect_resnet_prediction(image, test=False):
"""
Construct ResNet for MNIST (BinaryConnect version).
"""
def bn(x):
return PF.batch_normalization(x, batch_stat=not test)
def res_unit(x, scope):
C = x.shape[1]
with nn.parameter_scope(scope):
with nn.parameter_scope('conv1'):
h = F.elu(bn(PF.binary_connect_convolution(
x, C // 2, (1, 1), with_bias=False)))
with nn.parameter_scope('conv2'):
h = F.elu(
bn(PF.binary_connect_convolution(h, C // 2, (3, 3), pad=(1, 1), with_bias=False)))
with nn.parameter_scope('conv3'):
h = bn(PF.binary_connect_convolution(
h, C, (1, 1), with_bias=False))
return F.elu(x + h)
# Conv1 --> 64 x 32 x 32
with nn.parameter_scope("conv1"):
c1 = F.elu(
bn(PF.binary_connect_convolution(image, 64, (3, 3), pad=(3, 3), with_bias=False)))
# Conv2 --> 64 x 16 x 16
c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
# Conv3 --> 64 x 8 x 8
c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
# Conv4 --> 64 x 8 x 8
c4 = res_unit(c3, "conv4")
# Conv5 --> 64 x 4 x 4
c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
# Conv6 --> 64 x 4 x 4
c6 = res_unit(c5, "conv6")
pl = F.average_pooling(c6, (4, 4))
with nn.parameter_scope("classifier"):
y = bn(PF.binary_connect_affine(pl, 10))
return y
# Binary Net Model
def mnist_binary_net_lenet_prediction(image, test=False):
"""
Construct LeNet for MNIST (BinaryNet version).
"""
with nn.parameter_scope("conv1"):
c1 = PF.binary_connect_convolution(image, 16, (5, 5))
c1 = PF.batch_normalization(c1, batch_stat=not test)
c1 = F.binary_tanh(F.average_pooling(c1, (2, 2)))
with nn.parameter_scope("conv2"):
c2 = PF.binary_connect_convolution(c1, 16, (5, 5))
c2 = PF.batch_normalization(c2, batch_stat=not test)
c2 = F.binary_tanh(F.average_pooling(c2, (2, 2)))
with nn.parameter_scope("fc3"):
c3 = PF.binary_connect_affine(c2, 50)
c3 = PF.batch_normalization(c3, batch_stat=not test)
c3 = F.binary_tanh(c3)
with nn.parameter_scope("fc4"):
c4 = PF.binary_connect_affine(c3, 10)
c4 = PF.batch_normalization(c4, batch_stat=not test)
return c4
def mnist_binary_net_resnet_prediction(image, test=False):
"""
Construct ResNet for MNIST (BinaryNet version).
"""
def bn(x):
return PF.batch_normalization(x, batch_stat=not test)
def res_unit(x, scope):
C = x.shape[1]
with nn.parameter_scope(scope):
with nn.parameter_scope('conv1'):
h = F.binary_tanh(bn(PF.binary_connect_convolution(
x, C // 2, (1, 1), with_bias=False)))
with nn.parameter_scope('conv2'):
h = F.binary_tanh(
bn(PF.binary_connect_convolution(h, C // 2, (3, 3), pad=(1, 1), with_bias=False)))
with nn.parameter_scope('conv3'):
h = bn(PF.binary_connect_convolution(
h, C, (1, 1), with_bias=False))
return F.binary_tanh(x + h)
# Conv1 --> 64 x 32 x 32
with nn.parameter_scope("conv1"):
c1 = F.binary_tanh(
bn(PF.binary_connect_convolution(image, 64, (3, 3), pad=(3, 3), with_bias=False)))
# Conv2 --> 64 x 16 x 16
c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
# Conv3 --> 64 x 8 x 8
c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
# Conv4 --> 64 x 8 x 8
c4 = res_unit(c3, "conv4")
# Conv5 --> 64 x 4 x 4
c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
# Conv6 --> 64 x 4 x 4
c6 = res_unit(c5, "conv6")
pl = F.average_pooling(c6, (4, 4))
with nn.parameter_scope("classifier"):
y = bn(PF.binary_connect_affine(pl, 10))
return y
# Binary Weight Model
def mnist_binary_weight_lenet_prediction(image, test=False):
"""
Construct LeNet for MNIST (Binary Weight Network version).
"""
with nn.parameter_scope("conv1"):
c1 = PF.binary_weight_convolution(image, 16, (5, 5))
c1 = F.elu(F.average_pooling(c1, (2, 2)))
with nn.parameter_scope("conv2"):
c2 = PF.binary_weight_convolution(c1, 16, (5, 5))
c2 = F.elu(F.average_pooling(c2, (2, 2)))
with nn.parameter_scope("fc3"):
c3 = F.elu(PF.binary_weight_affine(c2, 50))
with nn.parameter_scope("fc4"):
c4 = PF.binary_weight_affine(c3, 10)
return c4
def mnist_binary_weight_resnet_prediction(image, test=False):
"""
Construct ResNet for MNIST (Binary Weight Network version).
"""
def bn(x):
return PF.batch_normalization(x, batch_stat=not test)
def res_unit(x, scope):
C = x.shape[1]
with nn.parameter_scope(scope):
with nn.parameter_scope('conv1'):
h = F.elu(bn(PF.binary_weight_convolution(
x, C // 2, (1, 1), with_bias=False)))
with nn.parameter_scope('conv2'):
h = F.elu(
bn(PF.binary_weight_convolution(h, C // 2, (3, 3), pad=(1, 1), with_bias=False)))
with nn.parameter_scope('conv3'):
h = bn(PF.binary_weight_convolution(
h, C, (1, 1), with_bias=False))
return F.elu(x + h)
# Conv1 --> 64 x 32 x 32
with nn.parameter_scope("conv1"):
c1 = F.elu(
bn(PF.binary_weight_convolution(image, 64, (3, 3), pad=(3, 3), with_bias=False)))
# Conv2 --> 64 x 16 x 16
c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
# Conv3 --> 64 x 8 x 8
c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
# Conv4 --> 64 x 8 x 8
c4 = res_unit(c3, "conv4")
# Conv5 --> 64 x 4 x 4
c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
# Conv6 --> 64 x 4 x 4
c6 = res_unit(c5, "conv6")
pl = F.average_pooling(c6, (4, 4))
with nn.parameter_scope("classifier"):
y = PF.binary_weight_affine(pl, 10)
return y
def train():
"""
Main script.
Steps:
* Parse command line arguments.
* Specify a context for computation.
* Initialize DataIterator for MNIST.
* Construct a computation graph for training and validation.
* Initialize a solver and set parameter variables to it.
* Create monitor instances for saving and displaying training stats.
* Training loop
* Compute error rate for validation data (periodically)
* Get a next minibatch.
* Set parameter gradients zero
* Execute forwardprop on the training graph.
* Execute backprop.
* Solver updates parameters by using gradients computed by backprop.
* Compute training error
"""
args = get_args(monitor_path='tmp.monitor.bnn')
# Get context.
from nnabla.ext_utils import get_extension_context
logger.info("Running in %s" % args.context)
ctx = get_extension_context(
args.context, device_id=args.device_id, type_config=args.type_config)
nn.set_default_context(ctx)
# Initialize DataIterator for MNIST.
data = data_iterator_mnist(args.batch_size, True)
vdata = data_iterator_mnist(args.batch_size, False)
# Create CNN network for both training and testing.
mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
if args.net == 'bincon':
mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
elif args.net == 'binnet':
mnist_cnn_prediction = mnist_binary_net_lenet_prediction
elif args.net == 'bwn':
mnist_cnn_prediction = mnist_binary_weight_lenet_prediction
elif args.net == 'bincon_resnet':
mnist_cnn_prediction = mnist_binary_connect_resnet_prediction
elif args.net == 'binnet_resnet':
mnist_cnn_prediction = mnist_binary_net_resnet_prediction
elif args.net == 'bwn_resnet':
mnist_cnn_prediction = mnist_binary_weight_resnet_prediction
# TRAIN
# Create input variables.
image = nn.Variable([args.batch_size, 1, 28, 28])
label = nn.Variable([args.batch_size, 1])
# Create prediction graph.
pred = mnist_cnn_prediction(image / 255, test=False)
pred.persistent = True
# Create loss function.
loss = F.mean(F.softmax_cross_entropy(pred, label))
# TEST
# Create input variables.
vimage = nn.Variable([args.batch_size, 1, 28, 28])
vlabel = nn.Variable([args.batch_size, 1])
# Create prediction graph.
vpred = mnist_cnn_prediction(vimage / 255, test=True)
# Create Solver.
solver = S.Adam(args.learning_rate)
solver.set_parameters(nn.get_parameters())
# Create monitor.
import nnabla.monitor as M
monitor = M.Monitor(args.monitor_path)
monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100)
monitor_verr = M.MonitorSeries("Test error", monitor, interval=10)
# Training loop.
for i in range(args.max_iter):
if i % args.val_interval == 0:
# Validation
ve = 0.0
for j in range(args.val_iter):
vimage.d, vlabel.d = vdata.next()
vpred.forward(clear_buffer=True)
ve += categorical_error(vpred.d, vlabel.d)
monitor_verr.add(i, ve / args.val_iter)
if i % args.model_save_interval == 0:
nn.save_parameters(os.path.join(
args.model_save_path, 'params_%06d.h5' % i))
# Training forward
image.d, label.d = data.next()
solver.zero_grad()
loss.forward(clear_no_need_grad=True)
# Training backward & update
loss.backward(clear_buffer=True)
solver.weight_decay(args.weight_decay)
solver.update()
# Monitor
e = categorical_error(pred.d, label.d)
monitor_loss.add(i, loss.d.copy())
monitor_err.add(i, e)
monitor_time.add(i)
parameter_file = os.path.join(
args.model_save_path, 'params_%06d.h5' % args.max_iter)
nn.save_parameters(parameter_file)
if __name__ == '__main__':
train()
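The MNIST-specific pieces above are the data iterator and the 1 x 28 x 28 input variables. Below is a minimal sketch of a replacement iterator for the PlantVillage images, under our own assumptions: the dataset is laid out as one folder per class, images are resized to 256 x 256, and the function name plantvillage_iterator is ours (data_iterator_simple is nnabla's generic iterator helper).

import os
import numpy as np
from PIL import Image
from nnabla.utils.data_iterator import data_iterator_simple

def plantvillage_iterator(root, batch_size, image_size=(256, 256)):
    # Each subfolder of `root` is one class (healthy or a specific disease).
    classes = sorted(os.listdir(root))
    samples = [(os.path.join(root, c, f), idx)
               for idx, c in enumerate(classes)
               for f in sorted(os.listdir(os.path.join(root, c)))]

    def load_func(i):
        path, label = samples[i]
        img = Image.open(path).convert('RGB').resize(image_size)
        x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)  # CHW
        return x, np.array([label], dtype=np.int32)

    return data_iterator_simple(load_func, len(samples), batch_size, shuffle=True)

In train(), the input variables would then become nn.Variable([args.batch_size, 3, 256, 256]), with the network's first layer and the class count adjusted to match.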