BruceFan's Blog

Stay hungry, stay foolish

0%

syzkaller运行流程结构如下图所示,红色标签表示需要配置的选项:

syz-manager用来启动、监控、和重启多个虚拟机实例,并在虚拟机里启动一个syz-fuzzer进程。它负责持久化corpus和存储crash。
syz-fuzzer在要测试的内核虚拟机上运行,syz-fuzzer指导fuzz进程(产生输入、变异、精简等)并通过RPC方式发送触发新路径的输入返回给syz-manager进程。它也会启动一个暂态syz-executor进程。
每个syz-executor进程执行一个输入(一套syscalls)。如:

1
2
3
4
mmap(&(0x7f000000000),(0x1000), 0x3, 0x32, -1, 0)
r0 = open(&(0x7f0000000000))="./file0", 0x3, 0x9)
read(r0, &(0x7f0000000000), 42)
close(r0)

syz-fuzzer进程接收输入来执行,并将结果返回。它被设计的尽可能简单(为了不干扰fuzz进程),用C++实现,编译为静态二进制,用共享内存通信。

源码分析

先从启动syzkaller的命令行工具syz-manager的源码开始分析,syz-manager的源码位于syz-manager/manager.go文件,首先是一个Manager结构体,里面包含了fuzz过程中的重要信息,如配置信息、虚拟机信息、测试目标信息等,具体内容后面会分析到。
接下来是syz-manager运行的main函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
func main() {
if sys.GitRevision == "" {
log.Fatalf("Bad syz-manager build. Build with make, run bin/syz-manager.")
}
flag.Parse() // 对命令行参数进行解析
log.EnableLogCaching(1000, 1<<20)
cfg, err := mgrconfig.LoadFile(*flagConfig) // flagConfig是命令行解析出来的配置文件,通过mgrconfig的LoadFile读取到cfg中
if err != nil {
log.Fatalf("%v", err)
}
target, err := prog.GetTarget(cfg.TargetOS, cfg.TargetArch) // 根据配置文件里的系统和架构获取目标信息,包括系统调用等,保存在Target结构体中
if err != nil {
log.Fatalf("%v", err)
}
sysTarget := targets.Get(cfg.TargetOS, cfg.TargetArch)
if sysTarget == nil {
log.Fatalf("unsupported OS/arch: %v/%v", cfg.TargetOS, cfg.TargetArch)
}
syscalls, err := mgrconfig.ParseEnabledSyscalls(target, cfg.EnabledSyscalls, cfg.DisabledSyscalls) // 可以去掉一些不感兴趣的syscall
if err != nil {
log.Fatalf("%v", err)
}
RunManager(cfg, target, sysTarget, syscalls)
}

pkg/mgrconfig里的config.go里定义了Config结构体,用来保存配置文件里的信息,load.go文件里定义了读取配置文件的方法。
prog/target.go中定义了Target结构体,GetTarget方法中用到了targets变量,targets变量在RegisterTarget方法中初始化,在RegisterTarget中添加debug.PrintStack(),发现RegisterTarget位于栈底,不知道是哪里调用了它。其实是在manager.go文件开头,import了sys包,在sys/sys.go文件中,import ( _ “**/sys/linux/gen)导入包前的下划线表示这个包里所有文件的init方法都会被执行,sys/linux/gen/里有386.go、amd64.go、arm64.go等,这些文件里都有init方法,init里调用了RegisterTarget方法,初始化了各个目标平台的target信息。GetTarget最后用sync.Once的Do方法确保target的初始化在整个程序(多线程环境)中只执行一次,内部通过互斥锁实现。
sysTarget的用处还没有细看。
接下来是RunManager方法,RunManager实现了启动虚拟机、http服务、rpc服务和log fuzz进程等操作。

1
2
3
4
5
6
7
8
9
func RunManager(cfg *mgrconfig.Config, target *prog.Target, sysTarget *targets.Target, syscalls []int) {
var vmPool *vm.Pool
if cfg.Type != "none" {
var err error
vmPool, err = vm.Create(cfg, *flagDebug) // 首先根据cfg中的虚拟机类型(Type如qemu)对vmPool进行初始化
if err != nil {
log.Fatalf("%v", err)
}
}

在vm/vm.go文件中,还是用import (_ “**/vm/qemu”)的方法,调用了所有导入文件里的init方法,qemu的init方法:

1
2
3
func init() {
vmimpl.Register("qemu", ctor, true)
}

调用了vm/vmimpl/vmimpl.go的Register方法:

1
2
3
4
5
6
func Register(typ string, ctor ctorFunc, allowsOvercommit bool) {
Types[typ] = Type{
Ctor: ctor,
Overcommit: allowsOvercommit,
}
}

再看vm/vm.go的Create方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func Create(cfg *mgrconfig.Config, debug bool) (*Pool, error) {
typ, ok := vmimpl.Types[cfg.Type] // 取出vmimpl中cfg.Type对应的虚拟机类型,如qemu
if !ok {
return nil, fmt.Errorf("unknown instance type '%v'", cfg.Type)
}
env := &vmimpl.Env{
Name: cfg.Name,
OS: cfg.TargetOS,
Arch: cfg.TargetVMArch,
Workdir: cfg.Workdir,
Image: cfg.Image,
SSHKey: cfg.SSHKey,
SSHUser: cfg.SSHUser,
Debug: debug,
Config: cfg.VM,
}
impl, err := typ.Ctor(env) // 用取出的Type类型构建vmimpl.Pool,vmimpl.Pool是interface,在这里typ.Ctor返回的是实现了这个接口的具体的Pool,如qemu的Pool
if err != nil {
return nil, err
}
return &Pool{ // 将Pool返回给RunManager的vmPool变量
impl: impl, // 这里的impl已经是qemu的Pool了
workdir: env.Workdir,
}, nil
}

再回到RunManager,进行mgr := &Manager创建mgr管理信息,initHTTP()创建HTTP服务器,startRPCServer(mgr)为fuzzer创建RPC服务器。接下来的go func() { for log },并发执行一个匿名函数,不停log fuzz进度。go加上方法表示并发执行这个方法。最后RunManager执行一个mgr.vmLoop()方法,vmLoop()方法会调用一个runInstance方法,runInstance调用mgr.vmPool.Create(index),这里是vm/vm.go的Create(),这个Create()会调用vmPool的impl的Create(workdir, index)方法,这里也就是qemu的Create(workdir, index)方法,qemu的Create方法会创建sshkey,并调用ctor方法,ctor方法会调用boot方法,boot方法会执行启动qemu的命令。vm.go的Create方法会返回启动虚拟机的实例,接下来runInstance方法会通过ssh拷贝fuzzerBin和executorBin到虚拟机实例,然后执行fuzzer二进制文件,并监控虚拟机的执行过程。

reference
How syzkaller works

宿主机要通过ssh访问虚拟机有两种网络配置方式,一种是用户模式网络,另一种是网桥网络模式。

qemu内部的用户模式网络

在没有任何-net参数时,qemu默认使用的是-net nic -net user参数,提供了一种用户模式(user-mode)的网络模拟。使用用户模式网络的虚拟机可以连通宿主机及外部网络,用户模式网络是完全由qemu自身实现,不依赖于其他工具,而且不需要root权限。qemu使用Slirp实现了一整套TCP/IP协议栈,并且使用这个协议栈实现了一套虚拟的NAT网络。
优点:

  • 使用简单
  • 独立性好
  • 虚拟机网络隔离性好

缺点:

  • 由于其在qemu内部实现所有网络协议栈,因此其性能较差
  • 不支持部分网络功能(如ICMP),所以不能在虚拟机中使用ping命令
  • 不能从宿主机或外部网络直接访问客户机

命令举例:

1
2
3
4
5
6
7
8
9
10
11
12
$ qemu-system-x86_64 \
-kernel $KERNEL/arch/x86/boot/bzImage \
-append "console=ttyS0 root=/dev/sda debug earlyprintk=serial slub_debug=QUZ" \
-hda $IMAGE/stretch.img \
-net user,hostfwd=tcp::10021-:22 \
-net nic,model=e1000 \
-enable-kvm \
-nographic \
-m 1G \
-smp 2 \
-pidfile vm.pid \
2>&1 | tee vm.log

-net user表示使用的是用户模式,hostfwd=tcp::10021-:22表示将虚拟机22端口转发到宿主机的10021端口。
-net nic表示为虚拟机创建虚拟机网卡,model=e1000表示为虚拟机添加一块e1000型的网卡,这也是qemu的默认类型。
查看qemu支持的NIC类型:

1
2
3
$ qemu-system-x86_64 -net nic,model=?
warning: TCG doesn't support requested feature: CPUID.01H:ECX.vmx [bit 5]
qemu: Supported NIC models: ne2k_pci,i82551,i82557b,i82559er,rtl8139,e1000,pcnet,virtio

这样配置好目录和参数以后启动qemu,有可能会遇到如下错误:

1
[FAILED] Failed to start Raise network interfaces.

这是虚拟机网络配置有问题,使用ip命令查看虚拟机使用的网卡,在虚拟机里执行:

1
2
3
4
5
6
7
8
9
10
11
# ip a s
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
inet6 ::1/128 scope host
valid_lft forever preferred_lft forever
2: enp0s3: <BROADCAST,MULTICAST> mtu 1500 qdisc noop state DOWN group default qlen 1000
link/ether 52:54:00:12:34:56 brd ff:ff:ff:ff:ff:ff
3: sit0@NONE: <NOARP> mtu 1480 qdisc noop state DOWN group default qlen 1000
link/sit 0.0.0.0 brd 0.0.0.0

可以看到有一个enp0s3网卡,查看网络配置文件:

1
2
3
# vi /etc/network/interfaces
auto eth0
iface eth0 inet dhcp

网络配置文件中的网卡和实际的网卡不一样,修改eth0enp0s3,重启网络:

1
2
3
4
5
6
# /etc/init.d/networking restart
Restarting networking (via systemctl): networking.service[ 287.214196] e1000: enp0s3 NIC Link is Up 1000 Mbps Full Duplex, Flow Control: RX
[ 287.227353] IPv6: ADDRCONF(NETDEV_CHANGE): enp0s3: link becomes ready
[ 287.446498] audit: type=1107 audit(1559202014.248:8): pid=1 uid=0 auid=4294967295 ses=4294967295 subj=system_u:system_r:kernel_t:s0 msg='avc: denied { reload } for auid=n/a uid=0 gid=0 path="/lib/systemd/system/ssh.service" cmdline="systemctl reload --no-block ssh.service" scontext=system_u:system_r:kernel_t:s0 tcontext=system_u:object_r:unlabeled_t:s0 tclass=service
[ 287.446498] exe="/lib/systemd/systemd" sauid=0 hostname=? addr=? terminal=?'
.

启动网卡成功,可以在宿主机ssh连接虚拟机了:

1
$ ssh -i identity_file -p 10021 root@localhost

使用网桥模式

网桥(bridge)模式可以让虚拟机和宿主机共享一个物理网络设备连接网络,虚拟机有自己的独立IP地址,可以直接连接与宿主机一样的网络,虚拟机可以访问外部网络,外部网络也可以直接访问客户机(就像访问普通物理主机一样)。即使宿主机只有一个网卡设备,使用bridge方式也可以让多个虚拟机与宿主机共享网络设备。
bridge模式需要添加-net tap参数,表明使用TAP设备。TAP是虚拟网络设备,它仿真了一个数据链路层设备(ISO七层网络结构的第二层),它像以太网的数据帧一样处理第二层数据包。TUN与TAP类似,也是一种虚拟网络设备,它是对网络层的仿真。TAP被用于创建一个网桥,而TUN与路由相关。

第一种linux下tap网络配置方法:

建立网桥:Ubuntu上需要安装建立虚拟网络设备的工具uml-utilities和桥接工具bridge-utils:

1
$ sudo apt install uml-utilities bridge-utils

创建一个bridge

1
$ sudo brctl addbr br0

清空网卡的IP

1
$ sudo ip addr flush dev eth0

添加网卡到bridge

1
$ sudo brctl addif br0 eth0

创建tap接口

1
$ sudo tunctl -t tap0 -u fanrong

添加tap0到bridge

1
$ sudo brctl addif br0 tap0

确定都已启动

1
2
3
$ sudo ifconfig eth0 up
$ sudo ifconfig tap0 up
$ sudo ifconfig br0 up

检查桥接是否恰当

1
$ sudo brctl show

为br0分配ip

1
2
$ brctl stp br0 on # 待定
$ sudo dhclient br0

启动命令

1
$ sudo qemu-system-i386 -cdrom Core-current.iso -boot d -netdev tap,id=mynet0,ifname=tap0,script=no,downscript=no -device e1000,netdev=mynet0,mac=52:55:00:d1:55:01

第二种linux下tap网络配置方法:

1.建立一个虚拟网络接口

1
$ sudo mknod /dev/net/tun c 10 200

2.建立网桥
修改/etc/network/interface配置文件。此处建立一个名为br0的网桥,先桥接上eth0,在启动qemu时,再桥接上tap0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
auto lo
iface lo inet loopback

auto br0
iface br0 inet static
address 192.168.1.2
network 192.168.1.0
netmask 255.255.255.0
broadcast 192.168.1.255
gateway 192.168.1.1
bridge_ports eth0
bridge_fd 9
bridge_hello 2
bridge_maxage 12
bridge_stp off

3.建立qemu-ifup脚本,启动qemu时调用

1
2
3
4
5
6
7
8
$ sudo apt install uml-utilities bridge-utils
$ brctl addbr br0
$ sudo vim /etc/qemu-ifup
#!/bin/sh
sudo /sbin/ifconfig $1 0.0.0.0 promisc up
sudo /sbin/brctl addif br0 $1
sleep 2
$ sudo chmod +x /etc/qemu-ifup

4.qemu启动命令

1
$ -net nic -net tap,ifname=tap0,script=/etc/qemu-ifup 

reference
http://smilejay.com/2016/09/kvm-user-mode-networking/
http://smilejay.com/2012/08/kvm-bridge-networking/

syzkaller官网上有介绍如何在Ubuntu宿主机上用qemu方法fuzz x86_64的Linux内核,但是步骤很分散,在好几个页面上,而且还可能有一些坑,后面会讲到。
首先介绍一下我的环境:

  • Ubuntu 16.04 x86_64
  • gcc 8.2.0
  • linux-5.1
  • go1.12.5

安装新版gcc

Syzkaller是一个coverage-guided fuzzer,因此需要编译内核有coverage support,gcc 6.1.0以后加入了coverage support。这里我是源码编译安装了gcc 8.2.0,大体步骤如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
$ wget http://ftp.tsukuba.wide.ad.jp/software/gcc/releases/gcc-8.2.0/gcc-8.2.0.tar.gz
$ tar xzvf gcc-8.2.0.tar.gz
$ cd gcc-8.2.0/
$ ./contrib/download_prerequisites
$ sudo apt install texinfo bison flex
$ mkdir build
$ cd build
$ ../configure --prefix=/usr/local/gcc --enable-bootstrap --enable-checking=release --enable-languages=c,c++ --disable-multilib
$ make -j8
$ sudo make install
$ vim ~/.bashrc
export PATH=/usr/local/gcc/bin:$PATH
$ source ~/.bashrc
$ gcc -v
...
gcc version 8.2.0 (GCC)

编译新版linux内核

linux kernel也需要coverage support,KCOV在linux kernel 4.6以后加入,可以用CONFIG_KCOV=y配置。
编译内核的大体步骤如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
$ wget http://cdn.kernel.org/pub/linux/kernel/v5.x/linux-5.1.tar.gz
$ tar zxvf linux-5.1.tar.gz
$ cd linux-5.1
$ make defconfig
$ make kvmconfig
$ vim .config
CONFIG_KCOV=y
CONFIG_DEBUG_INFO=y
CONFIG_KASAN=y
CONFIG_KASAN_INLINE=y
CONFIG_CONFIGFS_FS=y
CONFIG_SECURITYFS=y
$ make oldconfig # 使能这些选项使得一些子选项可用,一路回车即可
$ make -j8

创建linux镜像

接着需要创建一个debian-strech的Linux镜像:

1
2
3
4
5
6
7
8
$ sudo apt install debootstrap
$ mkdir IMAGE
$ cd IMAGE
$ wget https://raw.githubusercontent.com/google/syzkaller/master/tools/create-image.sh -O create-image.sh
$ chmod +x create-image.sh
$ ./create-image.sh
$ ls
chroot create-image.sh stretch.id_rsa stretch.id_rsa.pub stretch.img

安装qemu

1
sudo apt install qemu-system-x86

下面需要确保kernel能正常启动,sshd能正常运行,这里是一个坑点:

1
2
3
4
5
6
7
8
9
10
11
$ qemu-system-x86_64 \
-kernel linux-5.1/arch/x86/boot/bzImage \
-append "console=ttyS0 root=/dev/sda debug earlyprintk=serial slub_debug=QUZ"\
-hda IMAGE/stretch.img \
-net user,hostfwd=tcp::10021-:22 -net nic \
-enable-kvm \
-nographic \
-m 2G \
-smp 2 \
-pidfile vm.pid \
2>&1 | tee vm.log

这样启动的qemu可能存在[FAILED] Failed to start Raise network interfaces.的错误,原因是虚拟机启动后网卡名称为enp0s3(ip a s命令可以查看),而/etc/network/interfaces里的默认配置为eth0,网卡名称配置错误,启动不了网络接口,需要在虚拟机里修改interfaces文件。

之后可以在另一个终端ssh连接到虚拟机:

1
ssh -i IMAGE/stretch.id_rsa -p 10021 -o "StrictHostKeyChecking no" root@localhost

安装golang

syzkaller是go语言实现的,要编译安装需要Go 1.11+工具链。我安装的是go1.12.5。将go解压到/usr/local/目录,在~/创建gopath文件夹,在环境变量中添加GOROOT变量和GOPATH等:

1
2
3
4
$ vim ~/.bashrc
export GOROOT=/usr/local/go
export GOPATH=/home/fanrong/gopath
export PATH=$GOROOT/bin:$PATH

编译syzkaller

下载编译syzkaller源码:

1
2
3
4
5
$ go get -u -d github.com/google/syzkaller/...
$ cd $GOPATH/src/github.com/google/syzkaller
$ make
$ ls bin
linux_amd64 syz-db syz-manager syz-mutate syz-prog2c syz-repro syz-runtest syz-upgrade

直接运行make是在amd64平台上编译amd64版本的syzkaller,如果要交叉编译需要设置TARGETOSTARGETARCH等参数,详见Makefile。

启动syzkaller开始fuzz

在syzkaller目录创建一个配置文件my.cfg来指定fuzz中相关的信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{
"target": "linux/amd64",
"http": "127.0.0.1:56741",
"workdir": "workdir",
"kernel_obj": "/home/fanrong/Computer/kernel/linux-5.1",
"image": "/home/fanrong/Computer/kernel/IMAGE/stretch.img",
"sshkey": "/home/fanrong/Computer/kernel/IMAGE/stretch.id_rsa",
"syzkaller": ".",
"procs": 8,
"type": "qemu",
"vm": {
"count": 4,
"kernel": "$KERNEL/arch/x86/boot/bzImage",
"cpu": 2,
"mem": 2048
}
}

然后就可以在syzkaller目录中运行fuzz了,这里会遇到刚才启动qemu那个坑,前面把网卡改成了enp0s3才能正常运行sshd,而这里又出现了网卡不能启动的错误。这里要找出错误原因需要知道syzkaller启动虚拟机时的参数,一开始我打算看源码,在源码中log启动参数,后来发现syz-manager本身提供了调试信息的打印,只要在启动syz-manager的时候加上-debug选项,运行发现启动qemu的命令如下:

1
$ qemu-system-x86_64 -m 1024 -smp 1 -net nic,model=e1000 -net user,host=10.0.2.10,hostfwd=tcp::1569-:22 -display none -serial stdio -no-reboot -enable-kvm -cpu host,migratable=off -hda /home/fanrong/Computer/IoT/IMAGE/stretch.img -snapshot -kernel /home/fanrong/Computer/IoT/linux-5.1/arch/x86/boot/bzImage -append "earlyprintk=serial oops=panic nmi_watchdog=panic panic_on_warn=1 panic=1 ftrace_dump_on_oops=orig_cpu rodata=n vsyscall=native net.ifnames=0 biosdevname=0 root=/dev/sda console=ttyS0 kvm-intel.nested=1 kvm-intel.unrestricted_guest=1 kvm-intel.vmm_exclusive=1 kvm-intel.fasteoi=1 kvm-intel.ept=1 kvm-intel.flexpriority=1 kvm-intel.vpid=1 kvm-intel.emulate_invalid_guest_state=1 kvm-intel.eptad=1 kvm-intel.enable_shadow_vmcs=1 kvm-intel.pml=1 kvm-intel.enable_apicv=1"

单独运行这个命令,发现网卡确实不能启动,在qemu里运行ip a s命令,网卡名变成了eth0,所以又要把/etc/network/interfaces里的网卡名改回eth0,然后就可以正常启动syzkaller了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
$ mkdir workdir
$ ./bin/syz-manager -config=my.cfg
2019/05/30 18:29:01 loading corpus...
2019/05/30 18:29:01 serving http on http://127.0.0.1:56741
2019/05/30 18:29:01 serving rpc on tcp://[::]:42897
2019/05/30 18:29:01 booting test machines...
2019/05/30 18:29:01 wait for the connection from test machine...
2019/05/30 18:29:25 machine check:
2019/05/30 18:29:25 syscalls : 1387/2695
2019/05/30 18:29:25 code coverage : enabled
2019/05/30 18:29:25 comparison tracing : enabled
2019/05/30 18:29:25 extra coverage : extra coverage is not supported by the kernel
2019/05/30 18:29:25 setuid sandbox : enabled
2019/05/30 18:29:25 namespace sandbox : /proc/self/ns/user does not exist
2019/05/30 18:29:25 Android sandbox : enabled
2019/05/30 18:29:25 fault injection : CONFIG_FAULT_INJECTION is not enabled
2019/05/30 18:29:25 leak checking : CONFIG_DEBUG_KMEMLEAK is not enabled
2019/05/30 18:29:25 net packet injection : /dev/net/tun does not exist
2019/05/30 18:29:25 net device setup : enabled
2019/05/30 18:29:25 corpus : 0 (0 deleted)
2019/05/30 18:29:31 VMs 1, executed 0, cover 1349, crashes 0, repro 0
2019/05/30 18:29:41 VMs 1, executed 595, cover 2456, crashes 0, repro 0
2019/05/30 18:29:51 VMs 1, executed 1156, cover 3520, crashes 0, repro 0
2019/05/30 18:30:01 VMs 1, executed 1156, cover 7565, crashes 0, repro 0
2019/05/30 18:30:11 VMs 1, executed 1538, cover 8745, crashes 0, repro 0
...

可以在浏览器中查看fuzz进度:

reference
How to set up syzkaller
Setup: Ubuntu host, QEMU vm, x86-64 kernel
ubuntu16.04 编译安装gcc8.2.0

Google为了解决机器学习模型部署上线至生产环境,发布了Tensorflow Serving。本文主要通过部署一个手写数字识别的模型来介绍Tensorflow Serving的基本用法。

构建CNN模型及checkpoint保存方式

如果不需要Tensorflow Serving部署模型的话,大部分人会选择传统的checkpoint方式保存训练好的模型,下面先看一下这种传统的保存方式:
代码清单 mnist_test.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import cv2

# 屏蔽waring信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""------------------加载数据---------------------"""
# 载入数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# 改变数据格式,为了能够输入卷积层
trX = trX.reshape(-1, 28, 28, 1) # -1表示不考虑输入图片的数量,1表示单通道
teX = teX.reshape(-1, 28, 28, 1)

"""------------------构建模型---------------------"""
# 定义输入输出的数据容器
X = tf.placeholder("float", [None, 28, 28, 1], name="X")
Y = tf.placeholder("float", [None, 10])


# 定义和初始化权重、dropout参数
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01))


w1 = init_weights([3, 3, 1, 32]) # 3X3的卷积核,获得32个特征
w2 = init_weights([3, 3, 32, 64]) # 3X3的卷积核,获得64个特征
w3 = init_weights([3, 3, 64, 128]) # 3X3的卷积核,获得128个特征
w4 = init_weights([128 * 4 * 4, 625]) # 从卷积层到全连层
w_o = init_weights([625, 10]) # 从全连层到输出层

p_keep_conv = tf.placeholder("float", name="p_keep_conv")
p_keep_hidden = tf.placeholder("float", name="p_keep_hidden")


# 定义模型
def create_model(X, w1, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
# 第一组卷积层和pooling层
conv1 = tf.nn.conv2d(X, w1, strides=[1, 1, 1, 1], padding='SAME')
conv1_out = tf.nn.relu(conv1)
pool1 = tf.nn.max_pool(conv1_out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool1_out = tf.nn.dropout(pool1, p_keep_conv)

# 第二组卷积层和pooling层
conv2 = tf.nn.conv2d(pool1_out, w2, strides=[1, 1, 1, 1], padding='SAME')
conv2_out = tf.nn.relu(conv2)
pool2 = tf.nn.max_pool(conv2_out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool2_out = tf.nn.dropout(pool2, p_keep_conv)

# 第三组卷积层和pooling层
conv3 = tf.nn.conv2d(pool2_out, w3, strides=[1, 1, 1, 1], padding='SAME')
conv3_out = tf.nn.relu(conv3)
pool3 = tf.nn.max_pool(conv3_out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool3 = tf.reshape(pool3, [-1, w4.get_shape().as_list()[0]]) # 转化成一维的向量
pool3_out = tf.nn.dropout(pool3, p_keep_conv)

# 全连层
fully_layer = tf.matmul(pool3_out, w4)
fully_layer_out = tf.nn.relu(fully_layer)
fully_layer_out = tf.nn.dropout(fully_layer_out, p_keep_hidden)

# 输出层
out = tf.matmul(fully_layer_out, w_o)

return out


model = create_model(X, w1, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

# 定义代价函数、训练方法、预测操作
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=model, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(model, 1, name="predict")

# 定义一个saver
saver=tf.train.Saver()

# 定义存储路径
ckpt_dir="./ckpt_dir"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)

"""------------------训练模型或者加载模型进行测试---------------------"""
train_batch_size = 128 # 训练集的mini_batch_size=128
test_batch_size = 256 # 测试集中调用的batch_size=256
epoches = 5 # 迭代周期
with tf.Session() as sess:
"""-------训练模型--------"""
# 初始化所有变量
tf.global_variables_initializer().run()

# 训练操作
# for i in range(epoches):
# train_batch = zip(range(0, len(trX), train_batch_size),
# range(train_batch_size, len(trX) + 1, train_batch_size))
# for start, end in train_batch:
# sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
# p_keep_conv: 0.8, p_keep_hidden: 0.5})
# # 每个周期用测试集中随机抽出test_batch_size个图片进行测试
# test_indices = np.arange(len(teX)) # 返回一个array[0,1...len(teX)]
# np.random.shuffle(test_indices) # 打乱这个array
# test_indices = test_indices[0:test_batch_size]
#
# # 获取测试集test_batch_size章图片的的预测结果
# predict_result = sess.run(predict_op, feed_dict={X: teX[test_indices],
# p_keep_conv: 1.0,
# p_keep_hidden: 1.0})
# # 获取真实的标签值
# true_labels = np.argmax(teY[test_indices], axis=1)
#
# # 计算准确率
# accuracy = np.mean(true_labels == predict_result)
# print("epoch", i, ":", accuracy)
#
# # 保存模型
# saver.save(sess,ckpt_dir+"/model.ckpt",global_step=i)

"""-----加载模型,用导入的图片进行测试--------"""
# 载入图片
src = cv2.imread('./2.png')

# 将图片转化为28*28的灰度图
src = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
dst = cv2.resize(src, (28, 28), interpolation=cv2.INTER_CUBIC)

# 将灰度图转化为1*784的能够输入的网络的数组
picture = np.zeros((28, 28))
for i in range(0, 28):
for j in range(0, 28):
picture[i][j] = (255 - dst[i][j])
picture = picture.reshape(1, 28, 28, 1)

# 载入模型
saver.restore(sess, ckpt_dir+"/model.ckpt-4")
# 进行预测
predict_result = sess.run(predict_op, feed_dict={X: picture,
p_keep_conv: 1.0,
p_keep_hidden: 1.0})
print("你导入的图片是:", predict_result[0])

mnist_test.py文件包含了模型定义、训练和保存、加载。其中注释掉的代码为模型训练部分,去掉注释可以进行模型训练,训练后的模型保存在ckpt_dir中:

1
2
3
4
5
$ ls ckpt_dir
checkpoint model.ckpt-1.data-00000-of-00001 model.ckpt-2.index model.ckpt-3.meta
model.ckpt-0.data-00000-of-00001 model.ckpt-1.index model.ckpt-2.meta model.ckpt-4.data-00000-of-00001
model.ckpt-0.index model.ckpt-1.meta model.ckpt-3.data-00000-of-00001 model.ckpt-4.index
model.ckpt-0.meta model.ckpt-2.data-00000-of-00001 model.ckpt-3.index model.ckpt-4.meta

用Saved Model方式保存训练好的模型

为了能用Tensorflow Serving部署,需要用SavedModel方式保存模型。可以重新训练,再用SavedModel方式保存,也可以加载预训练的模型,保存为SavedModel方式。具体操作是在载入模型之后加入如下代码:

1
2
3
4
5
# Saved Model
tf.saved_model.simple_save(sess,
'savedmodel/1',
inputs={"X":X, "p_keep_conv":p_keep_conv, "p_keep_hidden":p_keep_hidden},
outputs={"predict":predict_op})

第一个参数sess为当前会话;第二个参数为保存模型的路径;第三个参数是模型输入,即使用模型进行预测时feed_dict里的变量;第四个参数是模型输出,对应模型预测时的第一个参数。
把训练模型部分的代码注释掉,运行mnist_test.py,会将模型保存为Tensorflow Serving可以用的形式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
$ ls savedmodel/1
saved_model.pb variables
# 查看模型的输入输出
$ saved_model_cli show --dir savedmodel/1 --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['X'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 28, 28, 1)
name: X:0
inputs['p_keep_conv'] tensor_info:
dtype: DT_FLOAT
shape: unknown_rank
name: p_keep_conv:0
inputs['p_keep_hidden'] tensor_info:
dtype: DT_FLOAT
shape: unknown_rank
name: p_keep_hidden:0
The given SavedModel SignatureDef contains the following output(s):
outputs['predict'] tensor_info:
dtype: DT_INT64
shape: (-1)
name: predict:0
Method name is: tensorflow/serving/predict

saved_model_cli命令行工具在安装过tensorflow之后就有了。

用Tensorflow Serving部署模型

接下来就可以用Tensorflow Serving对模型进行部署了,Tensorflow Serving可以提供gRPCREST两种方式,gRPC使用8500端口,REST使用8501端口,通过-p选项来映射docker和本地的端口,前面的是本地端口,后面的是docker端口,--name选项指定docker容器的名称,--mount选项挂载文件系统到容器,-e选项设置环境变量,-t选项分配虚拟终端:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
$ sudo docker run -p 8502:8501 --name mnist_model --mount type=bind,source=/home/fanrong/tfserving/savedmodel,target=/models/mnist_test -e MODEL_NAME=mnist_test -t tensorflow/serving
2019-05-09 07:31:04.975302: I tensorflow_serving/model_servers/server.cc:82] Building single TensorFlow model file config: model_name: mnist_test model_base_path: /models/mnist_test
2019-05-09 07:31:04.975674: I tensorflow_serving/model_servers/server_core.cc:461] Adding/updating models.
2019-05-09 07:31:04.975744: I tensorflow_serving/model_servers/server_core.cc:558] (Re-)adding model: mnist_test
2019-05-09 07:31:05.076332: I tensorflow_serving/core/basic_manager.cc:739] Successfully reserved resources to load servable {name: mnist_test version: 1}
2019-05-09 07:31:05.076402: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: mnist_test version: 1}
2019-05-09 07:31:05.076447: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: mnist_test version: 1}
2019-05-09 07:31:05.076533: I external/org_tensorflow/tensorflow/contrib/session_bundle/bundle_shim.cc:363] Attempting to load native SavedModelBundle in bundle-shim from: /models/mnist_test/1
2019-05-09 07:31:05.076560: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /models/mnist_test/1
2019-05-09 07:31:05.081699: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2019-05-09 07:31:05.109514: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:182] Restoring SavedModel bundle.
2019-05-09 07:31:05.146630: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:285] SavedModel load for tags { serve }; Status: success. Took 70041 microseconds.
2019-05-09 07:31:05.146734: I tensorflow_serving/servables/tensorflow/saved_model_warmup.cc:101] No warmup data file found at /models/mnist_test/1/assets.extra/tf_serving_warmup_requests
2019-05-09 07:31:05.147120: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: mnist_test version: 1}
2019-05-09 07:31:05.153768: I tensorflow_serving/model_servers/server.cc:313] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2019-05-09 07:31:05.158494: I tensorflow_serving/model_servers/server.cc:333] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 237] RAW: Entering the event loop ...

接下来要对上线的模型请求服务,需要编写客户端代码:
代码清单 client.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#coding:utf-8

import requests
import cv2
import numpy as np

URL = 'http://localhost:8502/v1/models/mnist_test:predict'

def main():
src = cv2.imread('./2.png')
src = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
dst = cv2.resize(src, (28, 28), interpolation=cv2.INTER_CUBIC)
picture = np.zeros((28,28))
for i in range(0, 28):
for j in range(0, 28):
picture[i][j] = (255 - dst[i][j])
picture = picture.reshape(1, 28, 28, 1)
predict_request = '{"inputs":{"X":%s, "p_keep_conv":1.0, "p_keep_hidden":1.0}}' % picture.tolist()
response = requests.post(URL, data=predict_request)
response.raise_for_status()
print(response.json()['outputs'])

if __name__ == '__main__':
main()

这里需要注意的是输入参数中的X,X本身是一个numpy数组类型,但是参数只能以字符串形式传递,如果直接用nunpy数组,转换为字符串是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
[[[[  0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
...
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]]]]

服务端接收到字符串以后无法处理,会报400 Client Error: Bad Request for url: http://localhost:8502/v1/models/mnist_test:predict
所以需要先把numpy数组转换为list类型,list类型转换为字符串是这样的:

1
[[[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], ...[0.0], [0.0], [0.0], [0.0], [0.0]]]]

服务端接收到字符串之后,可以识别为list类型,进行正常处理。
最后是这种效果:

1
2
$ python client.py
[2]

Ubuntu18.04安装

安装最新版Docker

1
$ wget -qO- https://get.docker.com/ | sh

Docker可以在容器内运行应用程序,使用docker run命令在容器内运行应用程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
$ sudo docker

Usage: docker [OPTIONS] COMMAND

A self-sufficient runtime for containers

Options:
--config string Location of client config files (default "/home/fanrong1/.docker")
-D, --debug Enable debug mode
-H, --host list Daemon socket(s) to connect to
-l, --log-level string Set the logging level ("debug"|"info"|"warn"|"error"|"fatal") (default "info")
--tls Use TLS; implied by --tlsverify
--tlscacert string Trust certs signed only by this CA (default "/home/fanrong1/.docker/ca.pem")
--tlscert string Path to TLS certificate file (default "/home/fanrong1/.docker/cert.pem")
--tlskey string Path to TLS key file (default "/home/fanrong1/.docker/key.pem")
--tlsverify Use TLS and verify the remote
-v, --version Print version information and quit

Management Commands:
builder Manage builds
config Manage Docker configs
container Manage containers
engine Manage the docker engine
...

使用Docker MySQL

下面以docker mysql为例介绍Docker的使用。
查找Docker Hub上的MySQL镜像

1
2
3
4
5
6
7
8
9
$ sudo docker search mysql
NAME DESCRIPTION STARS OFFICIAL AUTOMATED
mysql MySQL is a widely used, open-source relation… 8115 [OK]
mariadb MariaDB is a community-developed fork of MyS… 2753 [OK]
mysql/mysql-server Optimized MySQL Server Docker images. Create… 607 [OK]
zabbix/zabbix-server-mysql Zabbix Server with MySQL database support 191 [OK]
hypriot/rpi-mysql RPi-compatible Docker Image with Mysql 113
zabbix/zabbix-web-nginx-mysql Zabbix frontend based on Nginx web-server wi… 100 [OK]
...

拉取官方镜像,Tag为5.7

1
$ sudo docker pull mysql:5.7

查看本地镜像列表

1
2
3
4
5
$ sudo docker images
REPOSITORY TAG IMAGE ID CREATED SIZE
ubuntu 16.04 9361ce633ff1 8 weeks ago 118MB
tensorflow/serving latest 38bee21b2ca0 2 months ago 229MB
mysql 5.7 e47e309f72c8 3 months ago 372MB

启动mysql docker容器

1
2
$ sudo docker run --name test_docker_mysql -p 3307:3306 -e MYSQL_ROOT_PASSWORD=123456 -d mysql:5.7
65b2b644555b3730ac936caab2ccf9ad451b6e31f2509ed3e2333c9e20079873

查看docker启动情况

1
2
3
$ sudo docker ps
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
65b2b644555b mysql:5.7 "docker-entrypoint.s…" 5 minutes ago Up 5 minutes 33060/tcp, 0.0.0.0:3307->3306/tcp test_docker_mysql

到这已经可以使用docker的mysql容器提供的服务了,由于本地没有安装mysql客户端,所以这里用python去连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
#coding:utf-8

import pymysql

db = pymysql.connect("localhost", "root", "123456", "mysql", port=3307)
cursor = db.cursor()
sql = "select * from user"
cursor.execute(sql)
results = cursor.fetchall()
for row in results:
print(row)

db.close()

运行脚本结果:

1
2
3
4
5
$ python mysql_test.py
('localhost', 'root', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', '', b'', b'', b'', 0, 0, 0, 0, 'mysql_native_password', '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9', 'N', datetime.datetime(2019, 5, 9, 8, 47, 32), None, 'N')
('localhost', 'mysql.session', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'Y', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', '', b'', b'', b'', 0, 0, 0, 0, 'mysql_native_password', '*THISISNOTAVALIDPASSWORDTHATCANBEUSEDHERE', 'N', datetime.datetime(2019, 5, 9, 8, 47, 20), None, 'Y')
('localhost', 'mysql.sys', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', '', b'', b'', b'', 0, 0, 0, 0, 'mysql_native_password', '*THISISNOTAVALIDPASSWORDTHATCANBEUSEDHERE', 'N', datetime.datetime(2019, 5, 9, 8, 47, 20), None, 'Y')
('%', 'root', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', 'Y', '', b'', b'', b'', 0, 0, 0, 0, 'mysql_native_password', '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9', 'N', datetime.datetime(2019, 5, 9, 8, 47, 32), None, 'N')

使用Docker Tensorflow Serving部署模型

下载Tensorflow Serving的Docker镜像

1
2
$ sudo docker pull tensorflow/serving
$ git clone https://github.com/tensorflow/serving

设置demo模型的路径

1
$ TESTDATA="$(pwd)/serving/tensorflow_serving/servables/tensorflow/testdata"

启动Tensorflow Serving容器,打开REST API端口

1
$ sudo docker run -t -p 8501:8501 -v "$TESTDATA/saved_model_half_plus_two_cpu:/models/half_plus_two" -e MODEL_NAME=half_plus_two tensorflow/serving

用命令行请求模型的预测API

1
2
3
4
$ curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_two:predict
{ "predictions": [2.5, 3.0, 4.5] }
$ curl -d '{"inputs": 1.0}' -X POST http://localhost:8501/v1/models/half_plus_two:predict
{ "predictions": 2.5 }

这次主要是调试Android system/bt目录下的代码,也就是蓝牙相关的代码,是源于对几个CVE(cve-2018-9355~cve-2018-9362)的学习,这几个CVE是蓝牙相关的漏洞。

蓝牙相关代码分析

因为之前对蓝牙相关代码没有了解过,所以先对AOSP中关于蓝牙的代码进行学习。首先是从Android应用层编程开始,几个基本操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 获取设备Adapter
BluetoothAdapter mBluetoothAdapter = BluetoothAdapter.getDefaultAdapter();
// 表明此手机不支持蓝牙
if(mBluetoothAdapter == null){
return;
}
// 蓝牙未开启,则开启蓝牙
if(!mBluetoothAdapter.isEnabled()) {
mBluetoothAdapter.enable();
}
// 查询已绑定设备,发现新设备
mBluetoothAdapter.getBondedDevices();
mBluetoothAdapter.startDiscovery();

还有一些读写之类的操作先不赘述。这些API调用的是frameworks里的代码,如getDefaultAdapter()函数对应
frameworks/base/core/java/android/bluetooth/BluetoothAdapter.java文件里的getDefaultAdapter()函数:

1
2
3
4
5
6
7
8
9
10
11
12
public static synchronized BluetoothAdapter getDefaultAdapter() {
if (sAdapter == null) {
IBinder b = ServiceManager.getService(BLUETOOTH_MANAGER_SERVICE); // 这里通过Binder通信调用packages里的Android应用程序的服务
if (b != null) {
IBluetoothManager managerService = IBluetoothManager.Stub.asInterface(b);
sAdapter = new BluetoothAdapter(managerService);
} else {
Log.e(TAG, "Bluetooth binder is null");
}
}
return sAdapter;
}

上述代码片段中的BLUETOOTH_MANAGER_SERVICE是在
frameworks/base/services/core/java/com/android/server/BluetoothService.java中注册的,会通过Binder机制调用到
packages/apps/Bluetooth/src/com/android/bluetooth/btservice/AdapterService.java,AdapterService.enableNative()在
packages/apps/Bluetooth/jni/com_android_bluetooth_btservice_AdapterService.cpp中声明,通过jni调用HAL(hardware)里的代码。hardware里的代码会调用system里的代码实现,具体代码我就不分析了,主要介绍一下如何调试。

调试system中蓝牙相关代码

首先需要编译Android的源码,烧录到真机上,可以参考我之前的一篇文章,我的系统是Android6.0.1和Ubuntu 16.04。
前面分析了system的代码会由packages里的bluetooth应用程序调用,所以可以先尝试调试bluetooth的应用程序,跟踪代码到system层。接下来就是对Android应用native代码调试的操作,先手动打开蓝牙,打开一个终端(需要adb root):

1
2
3
4
5
6
$ adb shell
root@hammerhead:/ # ps | grep bluetooth
bluetooth 9134 2265 907552 46368 sys_epoll_ b6ce9894 S com.android.bluetooth
root@hammerhead:/ # gdbserver :1234 --attach 9134
Attached; pid = 9134
Listening on port 1234

另一个终端:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
$ cd Android6.0.1_r20
$ source build/envsetup.h
$ lunch 19
$ adb forward tcp:1234 tcp:1234
$ gdbclient 9134
(gdb) target remote :1234
Remote debugging using :1234
warning: Could not load shared library symbols for 9 libraries, e.g. /data/dalvik-cache/arm/system@framework@boot.oat.
Use the "info sharedlibrary" command to see the complete listing.
Do you need "set solib-search-path" or "set sysroot"?
Reading symbols from /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/bin/linker...done.
Loaded symbols for /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/bin/linker
Reading symbols from /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/lib/libcutils.so...done.
Loaded symbols for /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/lib/libcutils.so
...
Loaded symbols for /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/vendor/lib/libbt-vendor.so
__epoll_pwait () at bionic/libc/arch-arm/syscalls/__epoll_pwait.S:16
16 ldmfd sp!, {r4, r5, r6, r7}
(gdb)

比如我们调试蓝牙搜索的相关代码,
/packages/apps/Bluetooth/src/com/android/bluetooth/btservice/AdapterService.java中的startDiscovery()方法:

1
2
3
4
5
6
 boolean startDiscovery() {
enforceCallingOrSelfPermission(BLUETOOTH_ADMIN_PERM,
"Need BLUETOOTH ADMIN permission");

return startDiscoveryNative();
}

jni调用
/packages/apps/Bluetooth/jni/com_android_bluetooth_btservice_AdapterService.cpp中的startDiscoveryNative()

1
2
3
4
5
6
7
8
9
10
static jboolean startDiscoveryNative(JNIEnv* env, jobject obj) {
ALOGV("%s:",__FUNCTION__);

jboolean result = JNI_FALSE;
if (!sBluetoothInterface) return result;

int ret = sBluetoothInterface->start_discovery();
result = (ret == BT_STATUS_SUCCESS) ? JNI_TRUE : JNI_FALSE;
return result;
}

通过查看/packages/apps/Bluetooth/jni/Android.mk

1
LOCAL_MODULE := libbluetooth_jni

可以知道bluetooth的jni代码被编译成了libbluetooth_jni.so库,在
Android6.0.1-r20/out/target/product/hammerhead/symbols/system/lib/libbluetooth_jni.so,这个是带符号的,用IDA打开它。
打开一个终端,查看系统中so库加载的基址:

1
2
3
4
5
$ adb shell
# cat /proc/2286/maps | grep libbluetooth
b39e7000-b39fe000 r-xp 00000000 b3:19 921 /system/lib/libbluetooth_jni.so
b39ff000-b3a00000 r--p 00017000 b3:19 921 /system/lib/libbluetooth_jni.so
b3a00000-b3a01000 rw-p 00018000 b3:19 921 /system/lib/libbluetooth_jni.so

在IDA->Edit->Segments->Rebase program…,将Value改成0xb39e7000,找到android::startDiscoveryNative()函数:

在这个函数的起始地址是0xb39e9a80,在gdb中下断点,继续运行,让蓝牙对周围设备进行搜索,即可触发断点:

单步运行,可以运行到system/bt/btif/src/bluetooth.c中的start_discovery()函数(gdb中按ctrl+x+a):

/system/bin中的可执行文件调试

对于未启动进程:调试进程keystore

1
2
3
$ adb shell gdbserver :1234 /system/bin/keystore
Process /system/bin/keystore created; pid = 3990
Listening on port 1234

另一个终端(需要先source、lunch)

1
2
3
4
5
6
7
8
$ adb forward tcp:1234 tcp:1234
$ gdbclient 3990
Reading symbols from /home/fanrong/Computer/Android6.0.1-r20/out/target/product/hammerhead/symbols/system/bin/keystore...done.
(gdb) target remote :1234
__dl__start () at bionic/linker/arch/arm/begin.S:32
32 mov r0, sp
(gdb) si
33 bl __linker_init

对于已经启动的进程

1
2
$ adb shell
# gdbserver :1234 --attach pid

另一个终端

1
2
3
$ adb forward tcp:1234 tcp:1234
$ gdbclient pid
(gdb) target remote :1234

reference
蓝牙流程介绍
Android FrameWork学习(二)Android系统源码调试

继续上一篇文章的内容,我们已经有一个可以证实的区块链了,但是现在链中只保存着一些无用的信息,这篇文章中我们将会实现简单的钱包(wallet)交易(transaction),用交易来替换这些数据,创建一个非常简单的加密货币:”NoobCoin”。

创建Wallet

在加密货币中,币的所有权在区块链上通过交易转移,参与者有一个地址来发送和接收币。基本形式的钱包只能保存地址,但大多数钱包是能够在区块链上创建新交易的软件。

上图显示了钱包、交易和区块链的关系,钱包可以创建交易,交易被存储在区块中,一个区块中可以有多个交易。具体的内容下面会详细介绍。
首先创建一个钱包类来保存私钥和公钥,UTXOs的作用后面再介绍:
代码清单 wallet.py

1
2
3
4
5
6
7
8
9
10
import random
from ecdsa.util import PRNG
from ecdsa import SigningKey

class Wallet:
def __init__(self):
rng = PRNG(str(random.random()))
self.privateKey = SigningKey.generate(entropy=rng)
self.publicKey = self.privateKey.get_verifying_key()
self.UTXOs = {}

在noobcoin中,公钥实际上是作为地址使用的,被共享给其他人来接收付款。私钥是用来签名交易的,这样除了私钥的所有者,没有人能花我们的币。公钥也和交易一起发送,用来验证签名是合法的,数据没有被篡改。

这里用的椭圆曲线加密来产生公私钥对,用到了一个第三方库python-ecdsa

交易和签名

每个交易都会包含以下数据:

  • 资金发送者的公钥(地址)
  • 资金接收者的公钥(地址)
  • 转移资金的数目
  • 输入(inputs),引用先前的交易,证明发送者有资金可以发送
  • 输出(outputs),显示了接收资金的所有地址,这些输出在新交易中会被当做输入
  • 加密签名,证明地址的所有者是发送这个交易的人,并且数据没有被篡改

下面创建交易类:
代码清单 transaction.py

1
2
3
4
5
6
7
8
9
10
11
12
13
class Transaction:
def __init__(self, pubkey_from, pubkey_to, value, inputs):
self.transactionId = ""
self.sender = pubkey_from
self.reciepient = pubkey_to
self.value = value
self.inputs = inputs
self.outputs = []
self.sequence = 0

def calculateHash(self):
self.sequence += 1
return StringUtils.sha256(self.sender.to_string() + self.reciepient.to_string() + str(self.value) + str(self.sequence))

在创建产生签名的方法之前,我们需要先在StringUtils类中加一些辅助函数:
代码清单 strutils.py

1
2
3
4
5
6
7
@staticmethod
def applyECDSASig(prikey, inputs):
return prikey.sign(inputs)

@staticmethod
def verifyECDSASig(pubkey, data, signature):
return pubkey.verify(signature, data)

现在我们把签名方法用到Transaction类中,添加一个generateSignature()和verifySignature()方法:
代码清单 transaction.py

1
2
3
4
5
6
7
def generateSignature(self, prikey):
data = self.sender.to_string() + self.reciepient.to_string() + str(self.value)
self.signature = StringUtils.applyECDSASig(prikey, data)

def verifySignature(self):
data = self.sender.to_string() + self.reciepient.to_string() + str(self.value)
return StringUtils.verifyECDSASig(self.sender, data, self.signature)

实际中会签名更多的信息,例如outputs/inputs和timestamp等,但现在我们只签名最少的。一个新的交易添加到一个区块时,矿工将会验证签名。

测试钱包和签名

测试一下上面的代码:
代码清单 noobchain.py

1
2
3
4
5
6
7
8
9
10
11
12
13
from wallet import Wallet
from transaction import Transaction
walletA = Wallet()
walletB = Wallet()

print "Private and public keys: "
print walletA.privateKey.to_pem()
print walletA.publicKey.to_pem()

transaction = Transaction(walletA.publicKey, walletB.publicKey, 5, None);
transaction.generateSignature(walletA.privateKey);
print "Is signature verified"
print transaction.verifySignature()

运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
$ python noobchain.py
Private and public keys:
-----BEGIN EC PRIVATE KEY-----
MF8CAQEEGIAnlP0QuZpVjYN8cF1g0VK+QIDhpySdUqAKBggqhkjOPQMBAaE0AzIA
BF5wfHb88KlyjGkVeSSnqge/4Q4vxDIlHkwu5BvrszPLaEHs00DVwFEFdmjp5wkv
vw==
-----END EC PRIVATE KEY-----

-----BEGIN PUBLIC KEY-----
MEkwEwYHKoZIzj0CAQYIKoZIzj0DAQEDMgAEXnB8dvzwqXKMaRV5JKeqB7/hDi/E
MiUeTC7kG+uzM8toQezTQNXAUQV2aOnnCS+/
-----END PUBLIC KEY-----

Is signature verified
True

上面的代码创建了两个钱包,打印了walletA的公私钥。创建了一个交易,用walletA的私钥给它签名。
接下来介绍outputs和inputs,并把交易保存到区块链。

outputs & inputs

加密货币是如何被拥有的

你要拥有一个比特币需要先接收一个比特币,分布式账本并不会真的添加一个比特币给你,减少发送者的一个比特币,发送者会引用以前的交易证明他有一个比特币并作为input,然后创建一个交易output来证明发送了一个比特币到你的地址。(交易的input就是引用先前的交易output作为证明)
你钱包的余额就是所有未使用的地址是你的交易output的总和。
这里我们遵循比特币的惯例,称未使用的交易输出为:UTXO.
下面来创建TransactionInput和TransactionOutput类:
代码清单 transaction.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class TransactionInput:
def __init__(self, transactionOutputId):
self.transactionOutputId = transactionOutputId
self.UTXO = None # 包含未使用的交易output

class TransactionOutput:
def __init__(self, pubkey_to, value, parentTransactionId):
self.reciepient = pubkey_to
self.value = value
self.parentTransactionId = parentTransactionId
self.id = StringUtils.sha256(pubkey_to.to_string()+str(value)+parentTransactionId)

def isMine(self, publicKey):
return publicKey == self.reciepient

大致的工作流程:初始钱包手动创建一个交易,给A钱包100个币,这个交易没有input,只有output。output显示接收者为A钱包,金额是100,output会被添加到全局UTXOs,这个交易会被添加到第0个区块,不会被验证。钱包A要给钱包B发送40个币,交易的input就是引用上一笔交易的output,证明自己有足够的币可以发送,交易会有两个output,一个是接收者为B钱包,金额是40,另一个是接收者为A钱包,金额是60(找零)。这两个output会被添加到全局UTXOs,并将交易的input引用的output从全局UTXOs中删除,这笔交易会被添加到区块上。

处理交易

链上的区块可能会接收到很多交易,区块链可能会很长,处理一个新的交易会花费大量的时间,因为我们必须要验证它的inputs。为了解决这个问题,就需要保存一个额外的集合:所有可用来做input的未花费output,因此我们新建一个config.py文件:

1
2
3
4
5
blockchain = []
# 未使用的output
UTXOs = {}
minimumTransaction = 0.1
difficulty = 3

在Transaction类中添加一个processTransaction()方法:
代码清单 transaction.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def processTransaction(self):
if self.verifySignature() == False:
print "#Transaction Signature failed to verify"
return False
# 收集交易input(确保他们是未花费的)
for i in self.inputs:
i.UTXO = config.UTXOs[i.transactionOutputId]
# 输入的和不能小于最小交易额
if self.getInputsValue() < config.minimumTransaction:
print "#Transaction Inputs too small: " + self.getInputsValue()
return False
# 产生交易输出
leftOver = self.getInputsValue() - self.value
self.transactionId = self.calculateHash()
# 发送value给reciptient
self.outputs.append(TransactionOutput(self.reciepient, self.value, self.transactionId))
# 发送找零给sender
self.outputs.append(TransactionOutput(self.sender, leftOver, self.transactionId))
for o in self.outputs:
config.UTXOs[o.id] = o
# 从UTXO删除交易inputs当做花费
for i in self.inputs:
if i.UTXO == None:
continue
del config.UTXOs[i.UTXO.id]
return True

# 返回输入值的和
def getInputsValue(self):
total = 0
for i in self.inputs:
if i.UTXO == None:
continue
total += i.UTXO.value
return total

# 返回输出值的和
def getOutputsValue(self):
total = 0
for o in self.outputs:
total += o.value
return total

我们用这个方法来做一些检查,确保交易是合法的。最后,我们从UTXOs中删除了input引用到的output,意味着一个交易的output只能被用做一次input。
最后再来更新钱包:

  • 收集我们的余额(通过遍历UTXOs列表来检查交易的output是否是自己的)
  • 为我们创建交易

代码清单 wallet.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import config
from transaction import Transaction, TransactionInput

class Wallet:
def __init__(self):
...
# 计算自己的余额
def getBalance(self):
total = 0
for item in config.UTXOs.items():
UTXO = item[1]
if UTXO.isMine(self.publicKey): # 将全局UTXOs中是自己的output添加到自己钱包的UTXOs
self.UTXOs[UTXO.id] = UTXO
total += UTXO.value
return total
# 发送金额给其他人
def sendFunds(self, _recipient, value):
if self.getBalance() < value:
print "#Not Enough funds to send transaction. Transaction Discarded."
return None
inputs = []
total = 0
# 需要引用到的output
for item in self.UTXOs.items():
UTXO = item[1]
total += UTXO.value
inputs.append(TransactionInput(UTXO.id))
if total > value:
break
# 创建交易
newTransaction = Transaction(self.publicKey, _recipient, value, inputs)
newTransaction.generateSignature(self.privateKey)
# 删除引用过的output
for i in inputs:
del self.UTXOs[i.transactionOutputId]
return newTransaction

添加交易到区块上

现在交易系统已经可以运行,我们需要把它实现到区块链上。把区块链上无用的数据替换为一个交易的列表。一个区块上可能会有上千个交易,对于计算hash来说需要包含的太多。因此,我们使用交易的merkle root。在StringUtils类中添加一个产生merkleroot的辅助方法:
代码清单 strutils.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@staticmethod
def getMerkleRoot(transactions):
count = len(transactions)
previousTreeLayer = []
for transaction in transactions:
previousTreeLayer.append(transaction.transactionId)
treeLayer = previousTreeLayer
while count > 1:
treeLayer = []
for i in range(1, len(previous)):
treeLayer.append(sha256(previousTreeLayer[i-1]+previousTreeLayer[i]))
count = len(treeLayer)
previousTreeLayer = treeLayer
merkleRoot = treeLayer[0] if len(treeLayer) == 1 else ""
return merkleRoot

接下来修改Block类:
代码清单 block.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Block:
def __init__(self, previousHash):
self.previousHash = previousHash
self.timeStamp = time.time()
self.transactions = []
self.nonce = 0
self.merkleRoot = ""
self.hash = self.calculateHash()

def calculateHash(self):
return StringUtils.sha256(self.previousHash+str(self.timeStamp)+str(self.nonce)+self.merkleRoot)

def mineBlock(self, difficulty):
self.merkleRoot = StringUtils.getMerkleRoot(self.transactions)
target = '0'*difficulty
while self.hash[0:difficulty] != target:
self.nonce += 1
self.hash = self.calculateHash()
print "Block Mined!!! : " + self.hash
# 将交易添加到这个区块上
def addTransaction(self, transaction):
if transaction == None:
return False
if self.previousHash != "0":
if transaction.processTransaction() != True:
print "Transaction failed to process. Discarded."
return False
self.transactions.append(transaction)
print "Transaction Successfully added to Block"
return True

运行noobcoin

我们需要测试用钱包发送和接收币,更新我们区块链的合法性检查。有许多方法来创建新币,在比特币区块链,矿工可以用一笔交易作为挖到区块的奖励。现在,我们直接释放我们想用到的所有币,在第0个区块(genesis block)。就像比特币,我们将硬编码Genesis区块。
我们来更新noobchain.py文件:

  • 一个Genesis区块,释放100个noobcoin给walletA
  • 一个更新的链合法性检查,将交易考虑在内
  • 一些测试交易,来看看是否运行正常

代码清单 noobchain.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import config
from block import Block
from wallet import Wallet
from transaction import Transaction, TransactionOutput, TransactionInput

walletA = Wallet()
walletB = Wallet()
coinbase = Wallet()

# 将coinbase里的100个coin给A
genesisTransaction = Transaction(coinbase.publicKey, walletA.publicKey, 100, None)
genesisTransaction.generateSignature(coinbase.privateKey)
genesisTransaction.transactionId = "0"
genesisTransaction.outputs.append(TransactionOutput(genesisTransaction.reciepient, genesisTransaction.value, genesisTransaction.transactionId))
config.UTXOs[genesisTransaction.outputs[0].id] = genesisTransaction.outputs[0]

print "Creating and Mining Genesis block... "
genesis = Block("0")
genesis.addTransaction(genesisTransaction)
genesis.mineBlock(config.difficulty)
config.blockchain.append(genesis)

# testing
block1 = Block(genesis.hash)
print "walletA's balance is: %f" % walletA.getBalance()
print "walletA is attempting to send funds (40) to walletB..."
block1.addTransaction(walletA.sendFunds(walletB.publicKey, 40))
block1.mineBlock(config.difficulty)
config.blockchain.append(block1)
print "walletA's balance is: %f" % walletA.getBalance()
print "walletB's balance is: %f" % walletB.getBalance()

block2 = Block(block1.hash)
print "walletA is attempting to send more funds (1000) than it has..."
block2.addTransaction(walletA.sendFunds(walletB.publicKey, 1000))
block2.mineBlock(config.difficulty)
config.blockchain.append(block2)
print "walletA's balance is: %f" % walletA.getBalance()
print "walletB's balance is: %f" % walletB.getBalance()

block3 = Block(block2.hash)
print "walletB is attempting to send funds (20) to walletA..."
block3.addTransaction(walletB.sendFunds(walletA.publicKey, 20))
print "walletA's balance is: %f" % walletA.getBalance()
print "walletB's balance is: %f" % walletB.getBalance()

运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
$ python noobchain.py
Creating and Mining Genesis block...
Transaction Successfully added to Block
Block Mined!!! : 000ecc366d3c553e131693c972dc77b7f2bb7c8156aa7125bb10c9e4052d8874
walletA's balance is: 100.000000
walletA is attempting to send funds (40) to walletB...
Transaction Successfully added to Block
Block Mined!!! : 0005af73ad51e2a0fa80444bdd284342c102dc202c9139889f62a3f989d8343d
walletA's balance is: 60.000000
walletB's balance is: 40.000000
walletA is attempting to send more funds (1000) than it has...
#Not Enough funds to send transaction. Transaction Discarded.
Block Mined!!! : 00035606ee8cd61c0afb21c496193fdfc67a17172befd6cae90d7a9410a27d5f
walletA's balance is: 60.000000
walletB's balance is: 40.000000
walletB is attempting to send funds (20) to walletA...
Transaction Successfully added to Block
walletA's balance is: 80.000000
walletB's balance is: 20.000000

我们的钱包现在可以在区块链上发送资金了,我们有了一个本地的加密货币。最后,再加上检查区块链合法性的方法即可:
代码清单 noobchain.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def isChainValid():
hashTarget = '0' * config.difficulty
tempUTXOs = {}
tempUTXOs[genesisTransaction.outputs[0].id] = genesisTransaction.outputs[0]
for i in range(1, len(config.blockchain)):
curblock = config.blockchain[i]
preblock = config.blockchain[i-1]
if curblock.hash != curblock.calculateHash():
print 'Current Hashes not equal'
return False
if preblock.hash != curblock.previousHash:
print 'Previous Hashes not equal'
return False
if curblock.hash[:config.difficulty] != hashTarget:
print '#This block hasnt been mined'
return False
for t in range(len(curblock.transactions)):
curTransaction = curblock.transactions[t]
if not curTransaction.verifySignature():
print '#Signature on Transaction %d is Invalid' % t
return False
if curTransaction.getInputsValue() != curTransaction.getOutputsValue():
print '#Inputs are not equal to outputs on Transaction %d' % t
return False
for i in curTransaction.inputs:
tempOutput = tempUTXOs[i.transactionOutputId]
if tempOutput == None:
print '#Referenced input on Transaction %d is Missing' % t
return False
if i.UTXO.value != tempOutput.value:
print '#Referenced input Transaction %d value is Invalid' % t
return False
del tempUTXOs[i.transactionOutputId]
for o in curTransaction.outputs:
tempUTXOs[o.id] = o
if curTransaction.outputs[0].reciepient != curTransaction.reciepient:
print '#Transaction %d output reciepient is not who it should be' % t
return False
if curTransaction.outputs[1].reciepient != curTransaction.sender:
print '#Transaction %d output change is not sender' % t
return False
print 'Blockchain is valid'
return True

项目完整代码
reference
https://medium.com/programmers-blockchain/creating-your-first-blockchain-with-java-part-2-transactions-2cdac335e0ce
https://github.com/warner/python-ecdsa
https://www.cnblogs.com/fengzhiwu/p/5524324.html

支持向量机是一种二分类模型。它的基本模型是定义在特征空间上的间隔最大的线性分类器,支持向量机还包括核技巧,这使它成为实质上的非线性分类器。支持向量机的学习策略就是间隔最大化,可形式化为一个求解凸二次规划的问题。
支持向量机学习方法包含构建由简至繁的模型:

  • 当训练数据线性可分时,通过硬间隔最大化(hard margin maximization),学习一个线性的分类器,即线性可分支持向量机。
  • 当训练数据近似线性可分时,通过软间隔最大化(soft margin maximization),也学习一个线性的分类器,即线性支持向量机。
  • 当训练数据线性不可分时,通过使用核技巧(kernel trick)及软间隔最大化,学习非线性支持向量机。

间隔与支持向量

给定训练样本集$D={(\boldsymbol{x_1},y_1),(\boldsymbol{x_2},y_2),…,(\boldsymbol{x_m},y_m)}, y_m\in{-1,+1}$,分类学习最基本的想法就是基于训练集D在样本空间中找到一个划分超平面,将不同类别的样本分开。

能将两类训练样本分开的超平面有很多,应该找两类训练样本正中间的划分超平面,因为该划分超平面对训练样本局部扰动的容忍性最好,产生的分类结果是最鲁棒的,对未见实例的泛化能力最强。
在样本空间中,划分超平面可通过如下线性方程来描述:

$$
\begin{aligned}
\boldsymbol{w^{T}x}+b=0
\end{aligned}\tag{1}
$$
其中$\boldsymbol{w}=(w_1;w_2;…;w_d)$为法向量,决定了超平面的方向,b为位移项,决定了超平面与原点之间的距离。划分超平面可被法向量$\boldsymbol{w}$和位移b确定,下面将其记为$(\boldsymbol{w},b)$。样本空间中任意点$\boldsymbol{x}$到超平面$(\boldsymbol{w},b)$的距离可写为

$$
\begin{aligned}
r=\frac{|\boldsymbol{w^Tx}+b|}{||\boldsymbol{w}||}
\end{aligned}\tag{2}
$$
假设超平面$(\boldsymbol{w},b)$能将训练样本正确分类,即对于$(\boldsymbol{x_i},y_i)\in D$,若$y_i=+1$,则有$\boldsymbol{w^Tx_i}+b>0$;若$y_i=-1$,则有$\boldsymbol{w^Tx_i}+b<0$。令

$$
\begin{cases}
\boldsymbol{w^Tx_i}+b \geqslant +1, & y_i=+1; \\
\boldsymbol{w^Tx_i}+b \leqslant -1, & y_i=-1.
\end{cases}\tag{3}
$$
如下图所示,距离超平面最近的这几个训练样本点使式(3)等号成立,它们被称为支持向量(support vector),这两类支持向量到超平面的距离之和为

$$
\begin{aligned}
\gamma=\frac{2}{||\boldsymbol{w}||}
\end{aligned}\tag{4}
$$
它被称为间隔(margin)

要找到具有最大间隔(maximum margin)的划分超平面,也就是要找到能满足式(3)中约束的参数$\boldsymbol{w}$和$b$,使得$\gamma$最大,即

$$
\begin{aligned}
& \max_{\boldsymbol{w},b}\frac{2}{||\boldsymbol{w}||} \\
& s.t. \ y_i(\boldsymbol{w^Tx_i}+b) \geqslant 1, \quad i=1,2,…,m.
\end{aligned}\tag{5}
$$
为了最大化间隔,需最大化$||\boldsymbol{w}||^{-1}$,等价于最小化$||\boldsymbol{w}||^2$。因此,式(5)可写为

$$
\begin{aligned}
& \min_{\boldsymbol{w},b}\frac{1}{2}||\boldsymbol{w}||^2 \\
& s.t. \ y_i(\boldsymbol{w^Tx_i}+b) \geqslant 1, \quad i=1,2,…,m.
\end{aligned}\tag{6}
$$
这就是支持向量机(Support Vector Machine)的基本型。

对偶问题

式(6)是一个凸二次规划(convex quadratic programming)问题,能直接用现成的优化计算包求解,但还有更高效的办法。对式(6)使用拉格朗日乘子法可得到其对偶问题(dual problem)。对式(6)的每条约束添加拉格朗日乘子$\alpha_i\geqslant0$,则该问题的拉格朗日函数可写为

$$
\begin{aligned}
L(\boldsymbol{w},b,\boldsymbol{\alpha})=\frac{1}{2}||\boldsymbol{w}||^2+\sum_{i=1}^m\alpha_i(1-y_i(\boldsymbol{w^Tx_i}+b)),
\end{aligned}\tag{7}
$$
其中$\boldsymbol{\alpha}=(\alpha_1;\alpha_2;…;\alpha_m)$。令$L(\boldsymbol{w},b,\boldsymbol{\alpha})$对$\boldsymbol{w}$和$b$的偏导为零可得

$$
\begin{eqnarray}
\boldsymbol{w} &=& \sum_{i=1}^m\alpha_iy_ix_i, \tag{8} \\
0 &=& \sum_{i=1}^m\alpha_iy_i. \tag{9}
\end{eqnarray}
$$
将式(8)代入(7),即可将$L(\boldsymbol{w},b,\boldsymbol{\alpha})$中的$\boldsymbol{w}$和$b$消去,再考虑式(9)的约束,就得到式(6)的对偶问题

$$
\begin{aligned}
& \max_\alpha\sum_{i=1}^m\alpha_i-\frac{1}{2}\sum_{i=1}^m\sum_{j=1}^m\alpha_i\alpha_jy_iy_j\boldsymbol{x_i^Tx_j} \\
& s.t. \ \sum_{i=1}^m\alpha_iy_i=0, \\
& \alpha_i \geqslant 0, \qquad i=1,2,…,m.
\end{aligned}\tag{10}
$$
解出$\alpha$后,求出$\boldsymbol{w}$与$b$即可得到模型

$$
\begin{aligned}
f(x)&=\boldsymbol{w^Tx}+b \\
&=\sum_{i=1}^m\alpha_iy_i\boldsymbol{x_i^Tx}+b
\end{aligned}\tag{11}
$$
从对偶问题(10)解出的$\alpha_i$是式(7)中的拉格朗日乘子,它对应着训练样本$(\boldsymbol{x_i},y_i)$。注意到式(6)中有不等式约束,因此上述过程需满足KKT(Karush-Kuhn-Tucker)条件,即要求

$$
\begin{cases}
\alpha_i\geqslant 0; \\
y_if(x_i)-1\geqslant 0; \\
\alpha_i(y_if(x_i)-1)=0.
\end{cases}\tag{12}
$$
于是,对任意训练样本$(\boldsymbol{x_i},y_i)$,总有$\alpha_i=0$或$y_if(\boldsymbol{x_i})=1$。若$\alpha_i=0$,则该样本不会在式(11)的求和中出现,也就不会对$f(x)$有任何影响;若$\alpha_i>0$,则必有$y_if(\boldsymbol{x_i})=1$,所对应的样本点位于最大间隔边界上,是一个支持向量。支持向量机的一个重要性质:训练完成后,大部分的训练样本都不需要保留,最终模型仅与支持向量有关。
式(10)是一个二次规划问题,可以用通用的二次规划算法求解,但是该问题的规模正比于训练样本数,人们通过利用问题本身的特性,提出了很多高效算法,SMO(Sequential Minimal Optimization)是最流行的一种。
SMO的思想是每次选取两个变量$\alpha_i$和$\alpha_j$,并固定其他的参数$\alpha_k$,求解式(10)获得更新后的$\alpha_i$和$\alpha_j$,不断迭代直至收敛。SMO采用了一个启发式:使选取的两变量所对应样本之间的间隔最大。仅考虑$\alpha_i$和$\alpha_j$时,式(10)中的约束可重写为:

$$
\begin{aligned}
\alpha_iy_i+\alpha_jy_j=-\sum_{k\neq i,j}\alpha_ky_k=c
\end{aligned}\tag{13}
$$
消去式(10)中的变量$\alpha_j$,则得到一个关于$\alpha_i$的单变量二次规划问题,仅有的约束是$\alpha_i\geqslant0$。这样的二次规划具有闭式解,不必调用数值优化算法即可高效地计算出更新后的$\alpha_i$和$\alpha_j$。
根据式(8)可求出$\boldsymbol{w}$,对于$b$,可以用任意一个支持向量的性质$y_s(\boldsymbol{w^Tx_s}+b)=1$来计算。当然现实任务中采用更鲁棒的做法,使用所有支持向量求解的平均值。

软间隔支持向量机

基础型的SVM的假设所有样本在样本空间是线性可分的(硬间隔),但现实中的情况通常不满足这种特性。为此,要引入软间隔(soft margin)的概念。

允许某些样本不满足约束 $y_i(\boldsymbol{w^Tx_i}+b)\geqslant1$,当然,在最大化间隔的同时,不满足约束的样本应尽可能少。

核函数

前面的讨论中,假设的训练样本是线性可分的,即存在一个划分超平面能将训练样本正确分类。然而现实任务中,原始样本空间内也许并不存在能正确划分两类样本的超平面。

这样的问题,可将样本从原始空间映射到一个更高维的特征空间,使得样本在这个特征空间内线性可分。令$\phi(\boldsymbol{x})$表示将$\boldsymbol{x}$映射后的特征向量,在特征空间中划分超平面所对应的模型可表示为

$$
\begin{aligned}
f(\boldsymbol{x})=\boldsymbol{w^T}\phi(\boldsymbol{x})+b
\end{aligned}\tag{14}
$$
其中$\boldsymbol{w}$和$b$是模型参数,间隔最大化类似式(6)

$$
\begin{aligned}
& \min_{\boldsymbol{w},b}\frac{1}{2}||\boldsymbol{w}||^2 \\
& s.t. \ y_i(\boldsymbol{w^T}\phi(\boldsymbol{x_i})+b)\geqslant1, \quad i=1,2,…,m.
\end{aligned}\tag{15}
$$
其对偶问题是

$$
\begin{aligned}
& \max_{\alpha}\sum_{i=1}^m\alpha_i-\frac{1}{2}\sum_{i=1}^m\sum_{j=1}^m\alpha_i\alpha_jy_iy_j\phi(\boldsymbol{x_i})\cdot\phi(\boldsymbol{x_j}) \\
& s.t. \ \sum_{i=1}^m\alpha_iy_i=0, \\
& \alpha_i\geqslant0, \quad i=1,2,…,m.
\end{aligned}\tag{16}
$$
式中$\phi(\boldsymbol{x_i})\cdot\phi(\boldsymbol{x_j})$为$\phi(\boldsymbol{x_i})$和$\phi(\boldsymbol{x_j})$的内积。
核函数的定义: 设$\mathcal{X}$是输入空间(欧式空间$\boldsymbol{R^n}$的子集或离散集合),又设$\mathcal{H}$为特征空间,如果存在一个从$\mathcal{X}$到$\mathcal{H}$的映射

$$
\phi(\boldsymbol{x}):\mathcal{X} \to \mathcal{H}
$$
使得对所有$\boldsymbol{x_i,x_j}\in\mathcal{X}$,函数$K(\boldsymbol{x_i,x_j})$满足条件

$$
K(\boldsymbol{x_i,x_j})=\phi(\boldsymbol{x_i})\cdot\phi(\boldsymbol{x_j})
$$
则称$K(\boldsymbol{x_i,x_j})$为核函数,$\phi(\boldsymbol{x})$为映射函数。
核技巧的想法是,在学习与预测中只定义核函数$K(\boldsymbol{x_i,x_j})$,而不显式定义映射函数。通常,直接计算核函数比较容易,计算映射函数的内积很困难,因为特征空间的维数可能很高,甚至是无穷维。

reference
《机器学习》
《统计学习方法》

分类决策树模型是一种描述对实例进行分类的树形结构,决策树由结点和有向边组成,结点有两种类型:内部结点和叶结点。内部结点表示一个特征或属性,叶结点表示一个类。
用决策树分类,从根结点开始,对实例的某一特征进行测试,根据测试结果,将实例分配到其子结点;这时,每一个子结点对应着该特征的一个取值。如此递归地对实例进行测试并分配,直至达到叶结点。最后将实例分到叶结点的类中。决策树学习的目的是为了产生一个泛化能力强,即处理未见示例能力强的决策树,其基本流程遵循简单且直观的分而治之(divide-and-conquer)策略。

决策树的构造

在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则已正确地划分数据分类。
下面我们将用ID3算法划分数据集(还有C4.5算法和CART算法,这里先不讨论)。每次划分数据集时我们只选取一个特征属性,那么就需要知道选哪些特征作为划分的参考属性。
下表中有5个海洋动物,特征包括不浮出水面是否可以生存,是否有脚蹼。我们可以将这些动物分成两类:鱼类和非鱼类。要决定用第一个特征还是第二个特征划分数据,需要用到量化的方法——信息增益

序号 不浮出水面可以生存 有脚蹼 属于鱼类
1
2
3
4
5

信息增益

划分数据集的大原则是:将无序的数据变得更加有序。在划分数据集前后信息发生的变化称为信息增益(information gain),可以计算每个特征值划分数据集获得的信息增益,信息增益越大,说明划分之后的信息熵越小,即数据更加有序。所以信息增益最高的特征就是我们要选择的。
信息论中度量信息的方式称为熵(entropy),熵定义为信息的期望值。假定当前样本集合$D$中第$i$类样本所占的比例为$p_i$(i = 1,2,…,n),则$D$的信息熵定义为

$$
H(D) = -\sum_{i=1}^n p_i log_2 p_i
$$

$H(D)$的值越小,则$D$越有序。
假定离散属性$a$有$K$个可能的取值{$a^1,a^2,…,a^K$},用$a$对样本集$D$进行划分,则会产生$K$个分支结点,其中第$k$个分支结点包含了$D$中所有属性$a$的属性值为$a^k$的样本,记为$D^k$。根据上面信息熵公式可以算出$H(D^k)$,再考虑到各个分支所包含的样本数不同,给其赋予权重$|D^v|/|D|$,于是可以计算出属性$a$对样本集$D$进行划分的信息增益:

$$
G(D,a) = H(D)-\sum_{k=1}^K \frac{|D^k|}{|D|} H(D^k)
$$

划分数据集

下面用Python来计算信息熵,创建trees.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from math import log

def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 计算信息数目
labelCounts = {}
# 将信息保存到字典中
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
# 计算每种类别的概率
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt

接下来在trees.py中创建简单鱼鉴定数据集:

1
2
3
4
5
6
7
8
9
10
11
12
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels

myDat, labels = createDataSet()
print myDat
print calcShannonEnt(myDat)

运行结果:

1
2
3
$ python trees.py
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0.970950594455

可以在数据集中添加数据来观察信息熵的变化情况。
另一种度量集合无需程序的方法是基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。本文对基尼不纯度也先不讨论。
接下来我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分是最好的。

1
2
3
4
5
6
7
8
9
10
11
12
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
# 去掉给定特征,重新构建数据集
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet

print splitDataSet(myDat, 0, 1)

运行结果:

1
2
3
4
$ python trees.py
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0.970950594455
[[1, 'yes'], [1, 'yes'], [0, 'no']]

下面我们遍历整个数据集,计算按各个特征划分数据集的信息增益,找到最好的特征划分方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def chooseBestFeatureToSplit(dataSet):
# 计算共有多少个特征
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
# 遍历所有的特征
for i in range (numFeatures):
# 特征i有哪些值
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
# 遍历所有的特征值
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
# 计算按特征i划分后各分支的权重
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature

print chooseBestFeatureToSplit(myDat)

运行结果:

1
2
3
4
5
$ python trees.py
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0.970950594455
[[1, 'yes'], [1, 'yes'], [0, 'no']]
0

递归构建决策树

构建决策树的流程是:获得原始数据集,然后基于最好的特征划分数据集,第一次划分之后,数据将被向下传递到树分支的下一个结点,下一个结点再次划分数据,因此可以采用递归方式。
递归结束的条件是:程序遍历完所有划分数据集的特征,或者每个分支下的所有实例都具有相同的分类。
如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要采用多数表决的方法决定该叶子结点的分类。

1
2
3
4
5
6
7
8
import operator
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

创建决策树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# 类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征,返回出现最多的类
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree

print createTree(myDat, labels)

运行结果:

1
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

决策树最终以嵌套字典的方式表示,最左边是第一个划分特征,不浮出水面,0表示不是鱼类,1再以脚蹼划分。这棵树包含了3个叶子结点和2个判断结点。

reference
《机器学习实战》
《机器学习》
http://atlantic8.github.io/2017/03/02/Decision-Tree/

关于卷积神经网络的原理这篇文章先不做介绍,推荐机器视角:长文揭秘图像处理和卷积神经网络架构卷积:如何成为一个很厉害的神经网络这两篇文章,这里记录一下PyTorch对卷积神经网络的实现。
代码清单 model.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn

# 卷积神经网络(两个卷积层)
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc = nn.Linear(7*7*32, num_classes) # full connection,即输出层

def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out

这里首先构建了一个两层的卷积神经网络,torch.nn.Squential是一个顺序的容器,模块按照构造函数顺序加入。
torch.nn.Conv2d是将2D卷积(卷积核是2D的)应用到输入,其参数依次为:

  • in_channels 输入图像的通道数
  • out_channels 卷积产生的通道数
  • kernel_size 卷积核大小
  • stride 卷积的步长
  • padding 在输入的各边0填充

输入为一个图像的Tensor形式,其中包含$N$词袋大小,$C$是通道数,$H$是输入的高度,$W$是输入的宽度,单位是pixel。下面的demo.py中可以看到一个图像的Tensor如何构造。
torch.nn.BatchNorm2dBatch Normalization对每个隐藏层的输入进行标准化,有加速收敛等好处。
torch.nn.ReLU是一个非线性激活函数,$ReLU(x) = max(0, x)$正数不变,负数都转化为0。
torch.nn.MaxPool2d是对输入进行一个2D最大池化。
torch.nn.Linear对输入的数据进行一个线性转变$y = Ax + b$,参数依次为:

  • in_features 输入样本的大小
  • out_features 输出样本的大小,这里是10个数字
  • bias

代码清单 train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from model import ConvNet

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)

model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)

# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
1
2
3
4
5
6
7
8
$ python train.py
Epoch [1/5], Step [100/600], Loss: 0.1743
Epoch [1/5], Step [200/600], Loss: 0.1452
Epoch [1/5], Step [300/600], Loss: 0.1029
Epoch [1/5], Step [400/600], Loss: 0.0549
Epoch [1/5], Step [500/600], Loss: 0.0608
...
Test Accuracy of the model on the 10000 test images: 99 %

这里首先介绍一下训练数据,PyTorch支持多种数据集,上述代码用到的是MNIST数据集,一般在训练和测试时都是直接使用官网提供的二进制数据集,但是作为初学者,不太理解这个二进制到底是什么东西,所以先用下面的代码将二进制文件解析为原始的图片和对应的标签:
代码清单 resolve.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from PIL import Image
import struct

def read_image(filename):
f = open(filename, 'rb')
index = 0
buf = f.read()
f.close()
magic, images, rows, columns = struct.unpack_from('>IIII' , buf , index)
index += struct.calcsize('>IIII')
for i in xrange(images):
image = Image.new('L', (columns, rows))
for x in xrange(rows):
for y in xrange(columns):
image.putpixel((y, x), int(struct.unpack_from('>B', buf, index)[0]))
index += struct.calcsize('>B')

print 'save ' + str(i) + 'image'
image.save('test/' + str(i) + '.png')

def read_label(filename, saveFilename):
f = open(filename, 'rb')
index = 0
buf = f.read()

f.close()

magic, labels = struct.unpack_from('>II' , buf , index)
index += struct.calcsize('>II')

labelArr = [0] * labels

for x in xrange(labels):
labelArr[x] = int(struct.unpack_from('>B', buf, index)[0])
index += struct.calcsize('>B')

save = open(saveFilename, 'w')

save.write(','.join(map(lambda x: str(x), labelArr)))
save.write('\n')

save.close()
print 'save labels success'

if __name__ == '__main__':
read_image('data/raw/t10k-images-idx3-ubyte')
read_label('data/raw/t10k-labels-idx1-ubyte', 'test/label.txt')

解析出的图片
解析出数据集中的图片后,我们可以随意选取图片来测试训练的模型model.ckpt
代码清单 demo.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch
import torchvision.transforms as transforms
from model import ConvNet
from PIL import Image


model_path = "./model.ckpt"
img_path = "./12.png"

model = ConvNet()
print "load pretrained model from %s" % model_path
model.load_state_dict(torch.load(model_path))

transformer = transforms.ToTensor()
# 将图像转换为灰度模式,即单通道
image = Image.open(img_path).convert('L')
#image.resize((28, 28), Image.BILINEAR)
# 将图像转换为Tensor
image = transformer(image)
# 为Tensor添加一维,表示batch
image = image.view(1, *image.size())

model.eval()
output = model(image)

preds = torch.max(output, 1)[1]

print preds.item()
1
2
3
$ python demo.py
load pretrained model from ./model.ckpt
9

可以看到用训练的模型可以正确识别出图片中的数字9。
reference
https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/convolutional_neural_network/main.py