물리학과 인공지능이 만나면..

Momentum

Data

data = np.loadtxt('../../data/linear-regression/ex1data1.txt', delimiter=',')
X = data[:, 0].reshape(data[:, 0].shape[0], 1) # Population
Y = data[:, 1].reshape(data[:, 1].shape[0], 1) # profit

# Standardization
scaler_x = StandardScaler()
scaler_y = StandardScaler()
X = scaler_x.fit_transform(X)
Y = scaler_y.fit_transform(Y)

scatter(X, Y)
title('Profits distribution')
xlabel('Population of City in 10,000s')
ylabel('Profit in $10,000s')
grid()

Momentum Update

물리학적인 관점에서 optimization problem을 바라봤을때, convergence rate를 향상시킬수 있는 방법이 있습니다.
Loss를 산등성이에서 height로 바라봤을때 물리학적으로 potential energy를 갖을수 있게 됩니다.) (TODO:추가 연구 필요)

negative gradient는 양옆 가파른 산골짜기의 한쪽을 타고 조금씩 내려오는 형태와 유사하기 때문에 SGD는 convergence까지 매우 느립니다. (특히 초기 steep gain이후에 계속해서 더 느려짐)

\[\begin{align} v &= \gamma v_{t-1} + \eta \nabla_{\theta} J(\theta; x^{(i)},y^{(i)}) \\ \theta &= \theta - v \end{align}\]

위의 SGD예제처럼 gradient값이 직접적으로 weights에 영향을 주는것이 아니라, gradient값은 오직 velocity에 영향을 주게 됩니다.
그 뒤 velocity는 weights값에 영향을 미치게 됩니다.

코드에서 구현은 안되어 있지만 일반적으로 momentum은 0.5에서 시작해서 끝날때쯤에는 0.99까지 가게 만드는게 일반적입니다

w = np.array([-0.1941133,  -2.07505268])

def predict(w, X):
    N = len(X)
    yhat = w[1:].dot(X.T) + w[0]
    yhat = yhat.reshape(X.shape)
    return yhat

def momentum_nn(X, Y, w, eta=0.1, gamma=0.5):
    N = len(X)
    v = np.zeros(w.shape[1:])
    v_b = np.zeros(1)

    for i in range(N):
        x = X[i]
        y = Y[i]
        yhat = predict(w, x)
        delta = y - yhat

        v = gamma*v + 2/N * eta * np.sum(-delta.dot(x))
        v_b = gamma*v_b + 2/N * eta * np.sum(-delta)

        w[1:] = w[1:] - v
        w[0] = w[0] - v_b
    return w


for i in range(1, 10):
    w = momentum_nn(X, Y, w)
    yhat = predict(w, X)

    axes = subplot(3, 3, i)
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

    scatter(X, Y)
    plot(X, yhat, color='red')

    yhats = np.where(yhat >= 0.5, 1, 0)
    accuracy = mean_squared_error(Y, yhats)
    print('Mean Squared Error (less is good):', accuracy)

Mean Squared Error (less is good): 1.76244027984
Mean Squared Error (less is good): 1.0
Mean Squared Error (less is good): 1.0
Mean Squared Error (less is good): 0.725250038743
Mean Squared Error (less is good): 0.637070721578
Mean Squared Error (less is good): 0.637070721578
Mean Squared Error (less is good): 0.651252102667
Mean Squared Error (less is good): 0.651252102667
Mean Squared Error (less is good): 0.657019385081