【TensorFlow 全连接神经网络】MNIST 手写数字识别

主要内容:
使用 TensorFlow 构建一个三层全连接（传统）神经网络，作为手写数字识别的多分类器：输入数字图片，预测其对应的数字（0–9），并在 MNIST 数据集上进行训练与测试。

# coding: utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math

# Load MNIST with one-hot encoded labels (downloaded/cached under ./mnist/).
mnist = input_data.read_data_sets("./mnist/", one_hot=True)

# Pull the (images, labels) pair out of every split up front.
x_train, y_train = mnist.train.images, mnist.train.labels
x_dev, y_dev = mnist.validation.images, mnist.validation.labels
x_test, y_test = mnist.test.images, mnist.test.labels

# Report the shape of every split so the data layout is visible.
print("Training set:", x_train.shape)
print("Training set labels:", y_train.shape)
print("Dev Set(Cross Validation set):", x_dev.shape)
print("Dev Set labels:", y_dev.shape)
print("Test Set:", x_test.shape)
print("Test Set labels:", y_test.shape)


def display_digit(index):
    """Plot training image number `index`, titled with its digit label.

    Reads the module-level x_train / y_train arrays and indexes them by row,
    so it expects them in (examples, features) layout, i.e. as they are
    before the script transposes them. The raw one-hot label row is printed
    first; its argmax is the digit shown in the title.
    """
    one_hot = y_train[index]
    print(one_hot)
    digit = one_hot.argmax(axis=0)
    # Each flattened image row holds 28x28 grey-scale pixels.
    pixels = x_train[index].reshape([28, 28])
    plt.title("Example: %d  Label: %d" % (index, digit))
    plt.imshow(pixels, cmap=plt.get_cmap("gray_r"))
    plt.show()

# Visual sanity check on one training example, then show its label's shape.
display_digit(5)
print(y_train[5].shape)


# Following Andrew's advice, lay samples out as columns:
# every array becomes (features, number of examples).
x_train, y_train = x_train.T, y_train.T
x_dev, y_dev = x_dev.T, y_dev.T
x_test, y_test = x_test.T, y_test.T
print("x_train shape", x_train.shape)
print("y_train shape", y_train.shape)


def random_mini_batches(X, Y, mini_batch_size=64, seed=None):
    """
    Shuffle (X, Y) column-wise and split them into mini-batches.

    Arguments:
    X -- input data, of shape (input size, number of examples)
    Y -- labels, of shape (n_y, number of examples); e.g. one-hot labels
         of shape (10, m). Column i of Y belongs to column i of X.
    mini_batch_size -- positive integer, number of examples per mini-batch;
                       the last mini-batch holds the remainder and may be smaller
    seed -- optional integer; when given, the shuffle is drawn from a private
            np.random.RandomState(seed) so results are reproducible. When None
            (default) the global NumPy RNG is used.

    Returns:
    mini_batches -- list of synchronous (mini_batch_X, mini_batch_Y) tuples;
                    an empty list when there are no examples

    Raises:
    ValueError -- if mini_batch_size is not positive
    """
    if mini_batch_size <= 0:
        raise ValueError("mini_batch_size must be a positive integer")

    m = X.shape[1]  # number of training examples
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Step 1: shuffle columns of X and Y with the same permutation so each
    # example stays paired with its label. Indexing with the permutation
    # array directly avoids a needless list conversion; no reshape is needed
    # because column indexing already keeps Y's (n_y, m) shape.
    permutation = rng.permutation(m)
    shuffled_X = X[:, permutation]
    shuffled_Y = Y[:, permutation]

    # Step 2: partition into consecutive slices; slicing past the end simply
    # truncates, so the final (possibly smaller) batch needs no special case.
    return [
        (shuffled_X[:, start:start + mini_batch_size],
         shuffled_Y[:, start:start + mini_batch_size])
        for start in range(0, m, mini_batch_size)
    ]

# Parameter initialisation settings.
# Layer sizes: 784 input features (28x28 pixels), two hidden layers with
# 64 and 128 units, and a 10-unit output layer (one unit per digit class).
layer_dims = [784, 64, 128, 10]

def init_parameters(layer_dims):
    """Create trainable weights and biases for every layer.

    Arguments:
    layer_dims -- sequence of layer sizes [n_0, n_1, ..., n_L]; n_0 is the
                  input size and n_L the output size

    Returns:
    parameters -- dict mapping "W<l>" to a tf.Variable of shape (n_l, n_{l-1})
                  and "b<l>" to a tf.Variable of shape (n_l, 1), for l = 1..L.
                  Variables are created in the order W1, b1, W2, b2, ...
    """
    # NOTE(review): both weights and biases use tf.random_normal's default
    # (mean 0, stddev 1); a fan-in-scaled init would likely start training
    # with much smaller logits.
    params = {}
    layer_pairs = zip(layer_dims[:-1], layer_dims[1:])
    for idx, (fan_in, fan_out) in enumerate(layer_pairs, start=1):
        params["W%d" % idx] = tf.Variable(tf.random_normal([fan_out, fan_in]))
        params["b%d" % idx] = tf.Variable(tf.random_normal([fan_out, 1]))
    return params

def forward_propagation(X, parameters):
    """Build the forward pass LINEAR -> RELU -> ... -> LINEAR -> RELU -> LINEAR.

    Works for any depth produced by init_parameters: the number of layers L
    is the number of consecutive "W1", "W2", ... keys present. For the
    3-layer network used in this script the graph is exactly
    Z1 = W1·X + b1, A1 = relu(Z1), Z2 = W2·A1 + b2, A2 = relu(Z2),
    Z3 = W3·A2 + b3.

    Arguments:
    X -- input tensor with examples stored as columns: (n_x, number of examples)
    parameters -- dict with "W1", "b1", ..., "WL", "bL"

    Returns:
    Z -- output of the last LINEAR unit (logits), shape (n_L, number of examples)

    Raises:
    ValueError -- if parameters contains no "W1" entry
    """
    num_layers = 0
    while "W%d" % (num_layers + 1) in parameters:
        num_layers += 1
    if num_layers == 0:
        raise ValueError("parameters must contain at least 'W1' and 'b1'")

    A = X
    # Hidden layers: affine transform followed by ReLU.
    for l in range(1, num_layers):
        Z = tf.add(tf.matmul(parameters["W%d" % l], A), parameters["b%d" % l])
        A = tf.nn.relu(Z)
    # Output layer stays linear; softmax is applied inside the cost function.
    Z = tf.add(tf.matmul(parameters["W%d" % num_layers], A), parameters["b%d" % num_layers])
    return Z

def compute_cost(Z3, Y):
    """
    Mean softmax cross-entropy between the logits and the true labels.

    Arguments:
    Z3 -- logits from the last LINEAR unit, shape (n_classes, number of examples)
    Y -- label placeholder with the same shape as Z3

    Returns:
    cost -- scalar Tensor holding the average loss over the examples
    """
    # tf.nn.softmax_cross_entropy_with_logits wants (examples, classes), so
    # both tensors are transposed from the column-per-example layout.
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=tf.transpose(Z3),
        labels=tf.transpose(Y),
    )
    return tf.reduce_mean(per_example_loss)

def tf_nn_model(X_train, Y_train, X_test, Y_test, layer_dims, learning_rate=0.001,
                num_epochs=100, minibatch_size=64, print_cost=True):
    """
    Build, train (Adam) and evaluate a fully connected softmax classifier.

    Examples are stored as columns: X_* has shape (n_x, m) and Y_* has shape
    (n_y, m) with labels along axis 0.

    Arguments:
    X_train, Y_train -- training inputs / labels, shapes (n_x, m) and (n_y, m)
    X_test, Y_test -- evaluation inputs / labels with the same row layout
    layer_dims -- layer sizes handed to init_parameters, e.g. [784, 64, 128, 10]
    learning_rate -- Adam learning rate
    num_epochs -- number of passes over the training set
    minibatch_size -- examples per mini-batch (the last one may be smaller)
    print_cost -- if True, print the epoch cost every 10 epochs

    Returns:
    parameters -- dict of the trained weights and biases as numpy arrays

    Side effects: adds ops to the default TensorFlow graph, shows a
    learning-curve plot, and prints train/test accuracy.
    """
    n_x, m = X_train.shape  # n_x: input size, m: number of training examples
    n_y = Y_train.shape[0]  # n_y: output size
    costs = []  # epoch cost sampled every 5 epochs, for the learning curve

    X = tf.placeholder(tf.float32, [n_x, None], name="X")
    Y = tf.placeholder(tf.float32, [n_y, None], name="Y")
    parameters = init_parameters(layer_dims)
    Z3 = forward_propagation(X, parameters)
    cost = compute_cost(Z3, Y)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(num_epochs):
            minibatches = random_mini_batches(X_train, Y_train, minibatch_size)
            # Divide by the number of batches actually run (including the
            # final partial one) so epoch_cost is a true mean; this also
            # avoids a zero divisor when m < minibatch_size.
            num_minibatches = len(minibatches)
            epoch_cost = 0.
            for minibatch_X, minibatch_Y in minibatches:
                _, minibatch_cost = sess.run([optimizer, cost],
                                             feed_dict={X: minibatch_X, Y: minibatch_Y})
                epoch_cost += minibatch_cost / num_minibatches
            if print_cost and epoch % 10 == 0:
                print("Cost after epoch %i: %f" % (epoch, epoch_cost))
            # Record the curve independently of print_cost so the plot below
            # is never empty.
            if epoch % 5 == 0:
                costs.append(epoch_cost)

        # Plot the learning curve (one point every 5 epochs).
        plt.plot(np.squeeze(costs))
        plt.ylabel('cost')
        plt.xlabel('epochs (per fives)')
        plt.title("Learning rate =" + str(learning_rate))
        plt.show()

        trained_parameters = sess.run(parameters)
        print("Parameters have been trained!")

        # Classes live on axis 0 (one column per example).
        correct_prediction = tf.equal(tf.argmax(Z3, axis=0), tf.argmax(Y, axis=0))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print("Train Accuracy:", accuracy.eval({X: X_train, Y: Y_train}))
        print("Test Accuracy:", accuracy.eval({X: X_test, Y: Y_test}))

        return trained_parameters


# Train the 3-layer network on the (transposed) training split and
# evaluate it on the test split.
tf_nn_model(
    x_train, y_train,
    x_test, y_test,
    layer_dims,
    learning_rate=0.001,
    num_epochs=100,
    minibatch_size=64,
    print_cost=True,
)


运行结果:

Cost after epoch 0: 75.913229
Cost after epoch 10: 1.541095
Cost after epoch 20: 0.436585
Cost after epoch 30: 0.174160
Cost after epoch 40: 0.090298
Cost after epoch 50: 0.064457
Cost after epoch 60: 0.044082
Cost after epoch 70: 0.035504
Cost after epoch 80: 0.022698
Cost after epoch 90: 0.023649

Parameters have been trained!
Train Accuracy: 0.994545
Test Accuracy: 0.9427
Out[106]:
{'W1': array([[ 0.2372188 ,  1.27198195, -0.6455391 , ...,  1.26290512,
         -0.69059598,  0.36647785],
        [-0.50644702, -0.74370074,  0.38941762, ..., -0.15578361,
         -0.31009915, -0.17434931],
        [-2.5437634 ,  0.44527429, -0.70932513, ..., -1.01713133,
         -0.14752612,  0.19787782],
        ..., 
        [ 3.25048923,  0.08093037,  0.77567875, ..., -0.79534328,
          1.43014407,  0.21873565],
        [-1.93292856, -0.19783179,  0.12327723, ..., -0.22539552,
          0.13556184,  0.87210643],
        [-0.93210453,  0.2583403 ,  1.58626533, ..., -1.69557643,
          0.31096032,  0.41782433]], dtype=float32),
 'W2': array([[ 0.66262263, -0.41401526,  0.83104825, ..., -0.28790367,
          1.44923198, -0.01293663],
        [-0.94457793, -0.47847596,  0.39193049, ..., -0.44852871,
          0.31511024, -0.12879851],
        [ 0.83933985, -0.25525221,  1.83002853, ..., -0.7023285 ,
          0.29116887,  1.32396758],
        ..., 
        [-1.21769059,  0.21980943,  0.05707775, ..., -0.70724338,
          0.13368286, -0.47907224],
        [-0.78505909, -0.26749918, -1.0756464 , ...,  0.10546964,
          0.59970111, -0.47928923],
        [ 1.57277954,  0.20598291, -0.38545936, ..., -0.68153149,
         -0.01901394, -1.09839475]], dtype=float32),
 'W3': array([[ 0.23412205,  1.4664923 ,  1.02762878, ...,  0.13184339,
          1.05118167, -0.00358887],
        [ 0.26813394,  0.295957  ,  1.49240541, ...,  0.82661223,
          0.67465705, -0.32320595],
        [ 1.19123352, -0.83540916,  0.07576221, ..., -0.58284307,
          0.32790881,  0.13413283],
        ..., 
        [ 0.43964136,  1.74946868, -0.54555362, ..., -0.1613521 ,
         -0.37434128,  0.80795258],
        [ 0.60402709,  0.05262127,  0.42084417, ...,  0.47054997,
         -0.32987207, -1.64671504],
        [-0.78972542,  0.7970084 , -0.60551286, ...,  1.74413514,
          0.6057446 , -0.28617254]], dtype=float32),
 'b1': array([[-0.4571954 ],
        [-0.30936778],
        [-0.83330458],
        [-1.68725026],
        [-1.42897224],
        [-1.04096746],
        [-0.54966289],
        [ 2.43672371],
        [ 1.36083376],
        [-1.51412904],
        [-2.0457561 ],
        [-2.69589877],
        [-0.23028924],
        [ 0.88664472],
        [-1.48165977],
        [-2.08099437],
        [ 0.43034646],
        [ 0.7627002 ],
        [ 0.40478835],
        [-0.51313281],
        [-1.18395376],
        [-0.36716571],
        [-1.98513615],
        [-0.58582592],
        [-0.77087468],
        [-0.9414832 ],
        [ 0.25200051],
        [-0.98766547],
        [ 0.31909475],
        [ 0.0800764 ],
        [-0.01556224],
        [ 0.83097136],
        [ 0.32423681],
        [ 1.24688494],
        [-0.02111918],
        [-2.12303662],
        [-1.69796181],
        [ 0.68959635],
        [-0.6191389 ],
        [-1.28080022],
        [-0.17510706],
        [-0.23040138],
        [-0.46036553],
        [ 1.56836855],
        [ 2.0383904 ],
        [-0.86711407],
        [-1.19858789],
        [-1.96049547],
        [ 1.14845157],
        [-0.75677299],
        [-2.4980433 ],
        [ 0.13432245],
        [ 0.24774934],
        [-0.10357552],
        [ 0.93644065],
        [-1.22094846],
        [ 1.15299678],
        [ 1.51815248],
        [-0.20407377],
        [-0.76557356],
        [ 0.5967567 ],
        [ 1.13081288],
        [-0.34519741],
        [-0.18847673]], dtype=float32),
 'b2': array([[ 0.28188977],
        [ 1.13188219],
        [-0.51833898],
        [ 1.55272174],
        [ 0.3362346 ],
        [-0.62963486],
        [-0.55736727],
        [-1.99950421],
        [ 1.64439845],
        [ 0.09734726],
        [-2.69561672],
        [ 0.29041779],
        [ 0.72709852],
        [ 0.43301356],
        [-0.43779549],
        [-0.6581856 ],
        [-2.80175161],
        [-0.41372192],
        [-2.09087038],
        [-0.47786576],
        [ 0.31763604],
        [ 1.85912359],
        [ 1.59187448],
        [-1.36818421],
        [-0.65758836],
        [-0.12403597],
        [ 1.05362165],
        [-0.30393735],
        [ 1.8399303 ],
        [-0.29227388],
        [ 0.75677097],
        [ 0.3613534 ],
        [-0.18842472],
        [-0.66885817],
        [-0.27949655],
        [-0.89438319],
        [-1.51220632],
        [ 0.93994361],
        [-1.54467905],
        [-1.00363708],
        [-0.57895792],
        [-0.52491599],
        [ 2.27655602],
        [-0.85130656],
        [ 0.04630496],
        [ 1.12568331],
        [-0.38881832],
        [-0.27415273],
        [-0.86503613],
        [ 0.96864253],
        [-0.9870069 ],
        [ 0.37869945],
        [-1.68591571],
        [-0.62210619],
        [-0.01916602],
        [ 0.11517724],
        [-0.29602063],
        [-1.42557037],
        [ 1.11371112],
        [-1.10030782],
        [-0.23480549],
        [-0.83260995],
        [ 0.78863978],
        [-0.44784972],
        [ 0.18259326],
        [ 1.48195684],
        [-0.32906139],
        [-1.4134475 ],
        [ 0.52768463],
        [-0.46708786],
        [-1.52612662],
        [ 0.30641365],
        [-1.06699479],
        [-1.44061339],
        [-1.39849806],
        [-0.65535295],
        [-0.17019601],
        [ 0.86427599],
        [ 0.51089519],
        [ 0.63639545],
        [-0.31796476],
        [-0.96631444],
        [-1.21334612],
        [ 0.79893589],
        [ 0.90393507],
        [ 1.05157661],
        [-0.1798792 ],
        [ 0.35506439],
        [-0.88265395],
        [-0.77211195],
        [-0.35244057],
        [-0.97597492],
        [ 1.81438792],
        [ 1.50866187],
        [ 1.76945257],
        [-2.2490623 ],
        [ 1.27219939],
        [ 0.11137661],
        [-0.03369612],
        [ 1.64185321],
        [ 0.14421514],
        [ 1.1957972 ],
        [ 0.10298974],
        [-1.63592625],
        [ 1.57520294],
        [-2.0683074 ],
        [-0.78121209],
        [-0.02082653],
        [ 0.88429558],
        [ 0.98407972],
        [-1.09006429],
        [ 0.44493109],
        [-1.88774467],
        [-2.0510056 ],
        [-1.04833782],
        [ 1.08415902],
        [-1.55531442],
        [-1.52134264],
        [ 0.23356596],
        [-0.70101881],
        [-0.25792068],
        [ 0.41581729],
        [-0.11349884],
        [-3.29242682],
        [-0.68287402],
        [ 1.45735371],
        [ 0.07658232],
        [-0.82881683]], dtype=float32),
 'b3': array([[ 0.99828368],
        [-0.78877753],
        [-1.29528141],
        [-1.95668292],
        [ 1.43690228],
        [-0.19944769],
        [ 1.00068772],
        [ 0.8051874 ],
        [ 0.80680549],
        [ 0.26735926]], dtype=float32)}
In [ ]:

​
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页