for textmining

선형회귀 파라메터 추정

|

이번 글에서는 선형회귀 모델의 계수를 추정하는 방법을 살펴보도록 하겠습니다. 이번 글은 고려대 김성범 교수님 강의와 ‘밑바닥부터 시작하는 데이터과학(조엘 그루스 지음, 인사이트 펴냄)’을 정리하였음을 먼저 밝힙니다. 그럼 시작하겠습니다.

선형회귀

선형회귀(Multiple Linear Regression)는 수치형 설명변수 X와 연속형 숫자로 이뤄진 종속변수 Y간의 관계를 선형으로 가정하고 이를 가장 잘 표현할 수 있는 회귀계수를 데이터로부터 추정하는 모델입니다. 다음 그림처럼 집 크기와 가격의 관계를 나타내는 직선을 찾는 것이 선형회귀의 목표입니다.

독립변수들로 이뤄진 행렬 $X$와 종속변수 벡터 $Y$가 주어졌을 때 다중선형회귀 모델은 다음과 같이 정의됩니다.

Direct Solution

선형회귀의 계수들은 실제값과 모델 예측값의 차이, 즉 오차제곱합(error sum of squares)을 최소로 하는 값들입니다. 이를 만족하는 최적의 계수들은 회귀계수에 대해 미분한 식을 0으로 놓고 풀면 아래와 같이 명시적인 해를 구할 수 있습니다. 다시 말해 우리에게 주어진 $X$, $Y$ 데이터만 가지고 계수를 단번에 추정할 수 있다는 이야기입니다.

분석 대상 데이터

분석 대상 데이터는 ‘밑바닥부터 시작하는 데이터 과학’에서 제시된 예시데이터입니다. $X$는 친구수, 근무시간, 박사학위 취득여부이고, $Y$는 사용자가 웹사이트에서 보내는 시간(분)에 해당합니다. $X$ 가운데 친구 수 변수 하나만 떼어 $Y$와의 관계를 2차원으로 도시하면 아래 그림과 같습니다.

다음은 데이터입니다. $X$는 네 개 요소로 구성돼 있는데 각각 상수항, 친구수, 근무시간, 박사학위 취득 여부에 해당합니다.

# [1(beta_0), 친구수, 근무시간, 박사학위 취득 여부]
input = [[1,49,4,0],[1,41,9,0],[1,40,8,0],[1,25,6,0],[1,21,1,0],[1,21,0,0],[1,19,3,0],[1,19,0,0],[1,18,9,0],[1,18,8,0],[1,16,4,0],[1,15,3,0],[1,15,0,0],[1,15,2,0],[1,15,7,0],[1,14,0,0],[1,14,1,0],[1,13,1,0],[1,13,7,0],[1,13,4,0],[1,13,2,0],[1,12,5,0],[1,12,0,0],[1,11,9,0],[1,10,9,0],[1,10,1,0],[1,10,1,0],[1,10,7,0],[1,10,9,0],[1,10,1,0],[1,10,6,0],[1,10,6,0],[1,10,8,0],[1,10,10,0],[1,10,6,0],[1,10,0,0],[1,10,5,0],[1,10,3,0],[1,10,4,0],[1,9,9,0],[1,9,9,0],[1,9,0,0],[1,9,0,0],[1,9,6,0],[1,9,10,0],[1,9,8,0],[1,9,5,0],[1,9,2,0],[1,9,9,0],[1,9,10,0],[1,9,7,0],[1,9,2,0],[1,9,0,0],[1,9,4,0],[1,9,6,0],[1,9,4,0],[1,9,7,0],[1,8,3,0],[1,8,2,0],[1,8,4,0],[1,8,9,0],[1,8,2,0],[1,8,3,0],[1,8,5,0],[1,8,8,0],[1,8,0,0],[1,8,9,0],[1,8,10,0],[1,8,5,0],[1,8,5,0],[1,7,5,0],[1,7,5,0],[1,7,0,0],[1,7,2,0],[1,7,8,0],[1,7,10,0],[1,7,5,0],[1,7,3,0],[1,7,3,0],[1,7,6,0],[1,7,7,0],[1,7,7,0],[1,7,9,0],[1,7,3,0],[1,7,8,0],[1,6,4,0],[1,6,6,0],[1,6,4,0],[1,6,9,0],[1,6,0,0],[1,6,1,0],[1,6,4,0],[1,6,1,0],[1,6,0,0],[1,6,7,0],[1,6,0,0],[1,6,8,0],[1,6,4,0],[1,6,2,1],[1,6,1,1],[1,6,3,1],[1,6,6,1],[1,6,4,1],[1,6,4,1],[1,6,1,1],[1,6,3,1],[1,6,4,1],[1,5,1,1],[1,5,9,1],[1,5,4,1],[1,5,6,1],[1,5,4,1],[1,5,4,1],[1,5,10,1],[1,5,5,1],[1,5,2,1],[1,5,4,1],[1,5,4,1],[1,5,9,1],[1,5,3,1],[1,5,10,1],[1,5,2,1],[1,5,2,1],[1,5,9,1],[1,4,8,1],[1,4,6,1],[1,4,0,1],[1,4,10,1],[1,4,5,1],[1,4,10,1],[1,4,9,1],[1,4,1,1],[1,4,4,1],[1,4,4,1],[1,4,0,1],[1,4,3,1],[1,4,1,1],[1,4,3,1],[1,4,2,1],[1,4,4,1],[1,4,4,1],[1,4,8,1],[1,4,2,1],[1,4,4,1],[1,3,2,1],[1,3,6,1],[1,3,4,1],[1,3,7,1],[1,3,4,1],[1,3,1,1],[1,3,10,1],[1,3,3,1],[1,3,4,1],[1,3,7,1],[1,3,5,1],[1,3,6,1],[1,3,1,1],[1,3,6,1],[1,3,10,1],[1,3,2,1],[1,3,4,1],[1,3,2,1],[1,3,1,1],[1,3,5,1],[1,2,4,1],[1,2,2,1],[1,2,8,1],[1,2,3,1],[1,2,1,1],[1,2,9,1],[1,2,10,1],[1,2,9,1],[1,2,4,1],[1,2,5,1],[1,2,0,1],[1,2,9,1],[1,2,9,1],[1,2,0,1],[1,2,1,1],[1,2,1,1],[1,2,4,1],[1,1,0,1],[1,1,2,1],[1,1,2,1],[1,1,5,1],[1,1,3,1],[1,1,10,1],[1,1,6,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,4,1],[1,1,9,1],[1,1,9,1],[1,1,4,1],[1,1,2,1],[1,1,9,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,1,1],[1,1,1,1],[1,1,5,1]]
# 사이트에서 보내는 시간(분)
output = [68.77,51.25,52.08,38.36,44.54,57.13,51.4,41.42,31.22,34.76,54.01,38.79,47.59,49.1,27.66,41.03,36.73,48.65,28.12,46.62,35.57,32.98,35,26.07,23.77,39.73,40.57,31.65,31.21,36.32,20.45,21.93,26.02,27.34,23.49,46.94,30.5,33.8,24.23,21.4,27.94,32.24,40.57,25.07,19.42,22.39,18.42,46.96,23.72,26.41,26.97,36.76,40.32,35.02,29.47,30.2,31,38.11,38.18,36.31,21.03,30.86,36.07,28.66,29.08,37.28,15.28,24.17,22.31,30.17,25.53,19.85,35.37,44.6,17.23,13.47,26.33,35.02,32.09,24.81,19.33,28.77,24.26,31.98,25.73,24.86,16.28,34.51,15.23,39.72,40.8,26.06,35.76,34.76,16.13,44.04,18.03,19.65,32.62,35.59,39.43,14.18,35.24,40.13,41.82,35.45,36.07,43.67,24.61,20.9,21.9,18.79,27.61,27.21,26.61,29.77,20.59,27.53,13.82,33.2,25,33.1,36.65,18.63,14.87,22.2,36.81,25.53,24.62,26.25,18.21,28.08,19.42,29.79,32.8,35.99,28.32,27.79,35.88,29.06,36.28,14.1,36.63,37.49,26.9,18.58,38.48,24.48,18.95,33.55,14.24,29.04,32.51,25.63,22.22,19,32.73,15.16,13.9,27.2,32.01,29.27,33,13.74,20.42,27.32,18.23,35.35,28.48,9.08,24.62,20.12,35.26,19.92,31.02,16.49,12.16,30.7,31.22,34.65,13.13,27.51,33.2,31.57,14.1,33.42,17.44,10.12,24.42,9.82,23.39,30.93,15.03,21.67,31.09,33.29,22.61,26.89,23.48,8.38,27.81,32.35,23.84]

$i$번째 데이터 포인트에 대한 회귀식은 다음과 같습니다. 우리는 여기에서 회귀계수로 이뤄진 벡터 $β$를 추정해야 합니다.

$(X^TX)^{-1}X^Ty$를 계산해 데이터로부터 명시적인 해를 구한 결과 $β$는 [30.63, 0.97, -1.87, 0.91]로 추정되었습니다.

경사하강법 같은 반복적인 방식으로 선형회귀 계수를 구할 수도 있습니다. 경사하강법이란 어떤 함수값을 최소화하기 위해 임의의 시작점을 잡은 후 해당 지점에서의 그래디언트(경사)를 구하고, 그래디언트의 반대 방향으로 조금씩 이동하는 과정을 여러번 반복하는 것입니다. 예컨대 아래 그림(출처)과 같습니다.

이 글에서는 경사하강법 가운데 Stochastic Gradient Descent(SGD) 기법을 쓰겠습니다. SGD는 반복문을 돌 때마다 개별 데이터 포인트에 대한 그래디언트를 계산하고 이 그래디언트의 반대 방향으로 파라메터를 업데이트해 함수의 최소값을 구하는 기법입니다.

SGD 관련 메인 코드는 다음과 같습니다. 코드에서 ‘value’는 우리가 최소화하고 싶은 값으로 선형회귀 모델에서는 오차제곱합을 가리킵니다. ‘target_fn’은 목적함수로 오차제곱합을 아웃풋으로 산출하는 함수를 의미합니다. ‘theta’는 해당 목적함수의 파라메터인데요, 우리 문제에선 $α$와 $β$를 말합니다. ‘gradient_fn’은 각 파라메터에 대한 목적함수의 그래디언트를 가리킵니다.

def minimize_stochastic(target_fn, gradient_fn, x, y, theta_0, alpha_0=0.01):
    # SGD 방식으로 gradient descent
    # minimize_batch보다 훨씬 빠르다

    data = zip(x, y)
    # theta_0를 초기 theta로
    theta = theta_0
    # alpha_0를 초기 이동거리(step_size)로
    alpha = alpha_0
    # 시작할 때의 최소값
    min_theta, min_value = None, float("inf")
    iterations_with_no_improvement = 0

    # 만약 100번 넘게 반복하는 동안 value가 더 작아지지 않으면 멈춤
    while iterations_with_no_improvement < 100:
        value = sum( target_fn(x_i, y_i, theta) for x_i, y_i in data )

        # 새로운 최솟값을 찾았다면
        if value < min_value:
            # 이 값을 저장
            min_theta, min_value = theta, value
            # 100번 카운트도 초기화
            iterations_with_no_improvement = 0
            # 기본 이동거리로 돌아감
            alpha = alpha_0

        # 만약 최솟값이 줄어들지 않는다면
        else:
            # 이동거리 축소
            alpha *= 0.9
            # 100번 카운트에 1을 더함
            iterations_with_no_improvement += 1
		
        # 반복문이 돌 때마다 in_random_order를 호출하기 때문에
        # 매 iter마다 그래디언트를 계산하는 순서가 달라짐
        for x_i, y_i in in_random_order(data):
            # 각 데이터 포인트에 대해 그래디언트를 계산
            gradient_i = gradient_fn(x_i, y_i, theta)
            # 기존 theta에서, 학습률(alpha)과 그래디언트를 뺀 것을 업데이트
            theta = vector_subtract(theta, scalar_multiply(alpha, gradient_i))

    return min_theta

그러면 이제는 value와 target_fn, gradient_fn, theta를 정의해야 합니다. 우선 우리가 구해야 하는 파라메터는 입력변수 개수($k$=3)+상수항(1) 길이를 가진 벡터 $β$이므로 beta는 4차원 벡터로 선언했습니다.

import random
random.seed(10)
# 추정 대상 beta (vector)
# beta = [beta_0, beta_1, ..., beta_k]
beta = [random.random() for x_i in input[0]] 

우리가 최소화하고자 하는 값(value)은 오차제곱합입니다. $i$번째 데이터 포인트에 대한 오차제곱(Squared Error)은 다음과 같은 식으로 나타낼 수 있습니다.

이를 ‘squared_error’ 함수로 표현할 수 있습니다.

def squared_error(x_i, y_i, beta):
    return error(beta, x_i, y_i) ** 2

def error(beta, x_i, y_i):
    # 실제값 y_i와 예측값 사이의 편차
    return y_i - predict(beta, x_i)

def dot(v, w):
    """v_1 * w_1 + ... + v_n * w_n"""
    return sum(v_i * w_i for v_i, w_i in zip(v, w))

def predict(beta, x_i):
    # 현재 회귀계수들을 가지고 예측
    # 예측값 = x_i와 beta의 선형결합
    # x_i = [1(beta_0), x_i1, x_i2, ..., x_ik]
    # beta = [beta_0, beta_1, ..., beta_k]
    return dot(x_i, beta)

gradient_fn은 다음과 같습니다. 목적함수인 ‘squared_error’를 ‘theta’로 미분한 값입니다. 이를 식으로 정리하면 다음과 같습니다.

이를 코드로 나타내면 다음과 같습니다.

def squared_error_gradient(x_i, y_i, beta):
    # i번째 오류제곱 값의 beta에 대한 편미분값
    return [-2 * x_ij * error(beta, x_i, y_i) for x_ij in x_i]

이밖에 ‘minimize_stochastic’ 구동에 필요한 함수도 정의하겠습니다.

def in_random_order(data):
    # Stochastic Gradient Descent 수행을 위한 함수로,
    # 한번 반복문을 돌 때마다 임의의 순서로 데이터 포인트를 반환
    # 데이터 포인트의 인덱스를 list로 생성
    indexes = [i for i, _ in enumerate(data)]
    # 이 인덱스를 랜덤하게 섞는다
    random.shuffle(indexes)
    # 이 순서대로 데이터를 반환한다
    for i in indexes:
        yield data[i]

def vector_subtract(v, w):
    """subtracts two vectors componentwise"""
    return [v_i - w_i for v_i, w_i in zip(v,w)]

def scalar_multiply(c, v):
    return [c * v_i for v_i in v]

마지막으로 코드 전체를 구동하는 명령문은 다음과 같습니다.

beta = minimize_stochastic(target_fn=squared_error,
                           gradient_fn=squared_error_gradient,
                           x=input,
                           y=output,
                           theta_0=beta,
                           alpha_0=0.0001)

SGD 기법으로 해를 구한 결과 $β$는 [30.55184071133236, 0.973395629801253, -1.8632386392995979, 0.9471533590851985]로 추정되었습니다. 이는 명시적 해와 유사합니다.

회귀계수를 얼마나 신뢰할 수 있나

데이터로부터 추정한 회귀계수는 진짜 회귀계수라고 할 수는 없습니다. 어디까지나 일부 데이터를 가지고 도출된 계수일 뿐더러 데이터에 노이즈가 끼어있을 수도 있기 때문이죠. 선형회귀 모델의 $j$번째 독립변수에 대한 추정 회귀계수 $β_j$를 추정 회귀계수의 표준편차 $σ_j$로 나눈 값($t_j$)은 $n-k$의 자유도를 지닌 $t$분포를 따른다는 사실이 알려져 있습니다.

그렇다면 회귀계수의 표준편차는 어떻게 구할까요? 이 때 쓰는 것이 bootstap입니다. 기존 데이터에서 중복을 허용된 재추출을 통해 새로운 데이터를 만들어내는 방법입니다. 우리가 갖고 있는 학습데이터에 bootstrap을 적용하여 여러 새로운 데이터를 생성하고, 이들 각각으로부터 추정 회귀계수를 도출할 수 있습니다.

예를 들어 특정 독립변수에 해당하는 회귀계수가 부트스트랩 데이터마다 크게 달라지지 않는다면 해당 회귀계수는 상당히 신뢰할 수 있을 겁니다. 반대로 계수가 크게 변한다면 추정된 계수는 신뢰할 수 없을 겁니다.

부트스트랩 데이터 개수만큼 회귀계수 추정값을 구하여 회귀계수의 표준편차 $σ_j$를 구한 뒤 이를 원래 데이터에서 구한 추정회귀 계수 $β_j$에 나눠줘 $t_j$를 계산합니다. 이 $t_j$가 $t$분포 상에서 어느 위치에 해당하는지를 따져서 진짜 회귀계수가 속해 있을 수 있는 신뢰구간을 구할 수 있게 됩니다.

부트스트랩 데이터 각각에 대해 회귀계수를 먼저 구해보겠습니다. 우선 데이터로부터 회귀계수를 추정하는 함수를 만들었습니다.

def estimate_beta(x, y):
    # 추정 대상 beta (vector)
    beta = [random.random() for _ in x[0]] 
    return minimize_stochastic(target_fn=squared_error,
                           gradient_fn=squared_error_gradient,
                               x=x,
                               y=y,
                               theta_0=beta,
                               alpha_0=0.0001)

이제는 부트스트랩 함수를 만들 차례입니다.

def bootstrap_sample(data):
    # 데이터 개수만큼 data로부터 무작위 재추출 (중복 허용)
    return [random.choice(data) for _ in data]

def bootstrap_statistic(data, stats_fn, num_samples):
    # num_sample개의 bootstrap 샘플에 대해 stats_fn을 적용
    return [stats_fn(bootstrap_sample(data))
            for _ in range(num_samples)]

def estimate_sample_beta(sample):
    # sample은 (x_i, y_i)로 구성된 리스트
    x_sample, y_sample = zip(*sample)
    return estimate_beta(x_sample, y_sample)

이제 함수를 실행해 봅시다.

bootstrap_betas = bootstrap_statistic(zip(input, output),
                                      estimate_sample_beta,
                                      100)

함수를 실행하면 100개의 부트스트랩 데이터에 대해 추정된 회귀계수의 쌍이 100개가 도출됩니다. 다음과 같습니다.

[[30.343293717263954, 0.955828442766452, -1.8249710660990814, 0.19262090720112343], 
[31.653090581087664, 0.8930047770677926, -1.8274769383940832, 0.003915653745906352], 
[30.062130836429326, 0.9735595051657044, -1.7409587831575586, -0.00993645069080973], 
[30.459536927917853, 0.9275688208209533, -1.970156796143431, 1.4162870885210503], 
[29.965245738566743, 0.9987255912413002, -1.8469983707870146, 1.2570210865880311], 
[30.688436939427994, 0.949721374789511, -1.791565743167013, 0.8490708882537681], 
[27.389069317384713, 1.1899927371032066, -1.6395966566222984, 2.201209799844561], 
[31.578900415746247, 0.963470036264341, -2.061237515990318, 0.6044582651703277], 
[29.9969538479282, 1.0563305743635187, -2.0370138317358597, 1.395128459047288], 
[30.71129774327976, 0.9412707718288738, -1.9049341856989015, 0.9975476331019565], 
[28.95704337705777, 1.0372235777797778, -1.7374326711464978, 2.017330974451086], 
[29.347442580389554, 0.9691244573476897, -1.6980453130961777, 1.758299750670592], 
[30.89303821117151, 0.9643400467987878, -2.0014985101424188, 0.9437323561405635], 
[31.033963018013235, 0.944512876776604, -1.8339869584884239, 1.3762002169288328], 
[29.7724071122711, 0.9637788276539256, -1.8712203112136903, 2.31343127588546], 
[30.712161312804817, 0.9787589052763698, -1.9637325222425992, 1.1877691857117059], 
[28.882828505113785, 1.0615985715436485, -1.8322410460473135, 2.282737207651317], 
[31.164357855104715, 0.9313158660935656, -1.9362632274947122, 1.190519238545326], 
[31.02991081861022, 0.9646849000104669, -2.048540894739888, 1.4641504377576209], 
[31.270216131821417, 0.9412983344101367, -1.9754923012624737, 0.9244505522676154], 
[28.972811257580727, 1.0626659243331988, -1.8172782368438227, 1.9772010790265988], 
[32.47326463848947, 0.9446160353528085, -1.8827511459492636, -0.555442587406212], 
[31.92477575189198, 0.9076629023019004, -1.833715795651844, 0.2543582162753717], 
[30.73653173338487, 0.9798526580610298, -1.8368633847870894, 1.3991627628282644], 
[31.052387396870795, 0.9465884318410379, -1.937853405314178, 1.6167719138597119], 
[30.838191991050294, 0.9547770115578801, -1.9540199662139435, 1.1143266882659546], 
[30.989458366759308, 0.9112712446387848, -1.8275575068998577, 1.1518607341417386], 
[31.499881551689803, 0.984790498881108, -2.017772833184302, 0.6959781050813002], 
[30.6294509899701, 0.9532711074826222, -1.9002085912683446, 0.41904632258007973], 
[31.871492871773135, 0.9640564217099683, -2.0592781960142137, -0.1325017160959066], 
[30.979163018183787, 0.9442339015956216, -1.9852723070949032, 1.018733620546921], 
[30.94960897962424, 0.9917550347995638, -1.879826427606344, -0.9344395120511418], 
[29.974194993513784, 0.9372566526240828, -1.6983466191579584, 1.413895309854322], 
[30.48871554272555, 0.9274603561875837, -1.6730463683580687, 0.7292989104575046], 
[29.976667988625856, 0.9711827384011987, -1.839308476190981, 1.9172708899219006], 
[30.692522310161685, 0.951972668383523, -1.7681190534968108, 0.4691416207459495], 
[28.288129698240517, 1.0948205153184756, -1.8171117444664553, 3.3038823592912854], 
[29.822963056351064, 1.0327653743440959, -1.869836637120832, 1.5486447223042832], 
[30.386808122240165, 0.9779154202444837, -1.8976774910373846, 2.1086881935517914], 
[30.741619085348518, 0.9217376178433008, -1.7198893947497884, -0.34894837617535585], 
[29.12768874650376, 1.0267526931238684, -1.7647591340646893, 1.6875260877074156], 
[30.24960760356648, 1.0121231682791874, -1.7376567600056418, 0.296763706300862], 
[31.035301020344654, 0.9431366335375586, -1.953384748218975, 1.1579872338326034], 
[30.5768817821496, 0.947021358453022, -1.7205109807167238, 0.783267985369746], 
[29.374898109259238, 1.0939268639920088, -1.6424719680107713, 1.140410049137131], 
[29.40635067130438, 1.0637366346977501, -1.874508187809908, 2.4511133687480937], 
[30.95063057843148, 0.936325594647627, -1.7737167596585748, 1.0343706367628467], 
[31.678274244399336, 0.9227541184400466, -1.9918345207726629, 0.5086225151581116], 
[29.98166257475079, 0.9605170287193446, -1.7577319419624025, 0.5120422638922413], 
[31.323616890901306, 0.8176945107731992, -1.6500326316218685, -0.8683616760822116], 
[30.302149825423008, 0.9737686518258278, -1.8128215160225314, 0.7033976674704967], 
[31.386502073193892, 0.882658769388189, -1.6825286945756024, -0.44420246297894167], 
[30.323858845499185, 1.0757705725166067, -2.071948018422145, 2.2968801639230105], 
[30.494825872775092, 1.0108387224084352, -1.9188457540102233, 0.8717781457244161], 
[30.576615201544183, 0.9120017225776094, -1.7251677135358494, 1.7832767592542957], 
[28.492384747280074, 1.1059593661729585, -1.7735172156402836, 2.663999434050509], 
[30.104972970041402, 1.0840732600063667, -1.9613608117473043, 1.5315485478845854], 
[31.757239201637674, 0.9596734272864389, -2.051685205450117, -0.2098820060083817], 
[29.854863672046072, 1.0209262822106175, -1.9439212461621536, 1.5185447249181654], 
[29.41528261809298, 1.1023172226932805, -1.891600842305552, 2.2231352104855175], 
[29.11300285931626, 1.0343057536723956, -1.7718637081275448, 1.6750740838338922], 
[29.39039719035395, 0.9297526750567007, -1.7896140395445401, 1.940908892848941], 
[27.962469288747435, 1.1422689401264716, -1.8893543588108115, 3.388049453882638], 
[29.099157058356646, 1.1348583693290986, -1.8317976244428604, 1.663037801381304], 
[29.137612379810427, 1.1300566948876725, -1.7913898643267325, 1.6482641446500326], 
[31.130222918185655, 0.9030219758748801, -1.9131692904927817, 1.474759950650903], 
[30.58125248343542, 0.9223235418040419, -1.6101280249243748, 1.0127434495055474], 
[30.891973319804126, 0.9193557426060519, -1.8375241356871626, 0.08461128745246431], 
[31.14771990918551, 0.9553305770320748, -1.878558304343514, 0.9002047428943074], 
[31.27678282839363, 0.9335965433539816, -1.9467073427528274, 0.9226402524279342], 
[29.56185170659549, 1.077105492914354, -1.8241298785196545, 1.595589221533047], 
[28.076434296423237, 1.0481641153277417, -1.6920455897818678, 1.8737381674617253], 
[28.90380434686726, 1.0604153971116186, -1.7365915272861396, 2.026884447765734], 
[29.99741704319044, 1.0666687600651432, -1.9111028667834695, 1.2256669988576885], 
[29.43736242616701, 1.0472064322799974, -1.8063134614295004, 1.9580133762596552], 
[29.227650225513607, 1.0144255050958706, -1.8016989061013566, 1.7520195371651819], 
[29.243279432899033, 0.9968822418447205, -1.778030031899557, 0.5551593556261414], 
[31.46888932259628, 0.9105043256707787, -1.7449900371478968, -0.7085272038815223], 
[29.924868298448885, 1.0239101298478512, -1.8708758584829717, 0.959971954552992], 
[30.88467060799466, 0.9380489354101933, -1.8289315911516493, 0.36156696066090643], 
[31.757648253366902, 0.96456365074318, -2.12281082826505, 0.4933404740397043], 
[30.359636241287728, 0.9930935207021235, -1.91614726756068, 0.11981014378409671], 
[29.689503027235425, 1.0275941570260614, -1.8684278630889204, 1.695936319407561], 
[30.675188326985708, 1.045569865938166, -1.9146598992756043, 1.0011598992584545], 
[30.43186544380135, 1.0040689827077782, -2.0168868872700227, 1.3592129670325508], 
[31.56317410958609, 0.9256366346831482, -1.8198369687354612, 0.37814324357003987], 
[32.0724948565735, 0.9134883775996528, -1.9211327000011957, -1.1653914277078095], 
[29.24454398282765, 0.9972606107989663, -1.7192410689836413, 1.0773787739745337], 
[29.565133763140803, 0.969925575012797, -1.6784388902600302, 1.16934425282509], 
[30.521873388440795, 0.9659927881039246, -1.81299769918763, 0.9066914988998109], 
[31.304315710353873, 0.9388059361772737, -1.7375886505932063, 0.22777427755049626], 
[30.485899853898623, 0.9831955516685744, -1.8216915684007813, 1.4665292532815406], 
[30.586128743092342, 0.954830980894403, -1.9606989544282951, 2.7599642168405083], 
[30.331607238578425, 0.9800674077114176, -1.9621069206081838, 2.229119033504103], 
[31.372326222839966, 1.0047554827884049, -2.0193398364534048, 0.9477139083724562], 
[30.231048927885226, 1.0141246515657603, -1.8918513233859755, 0.1924866352716917], 
[30.263326574321635, 0.9311507918160458, -1.9097048979881397, 2.9233469946337465], 
[31.515503088339216, 0.9178784037775057, -1.998991069549314, 1.2935446029780655], 
[30.03038991481507, 1.035274803835536, -2.028483027984759, 2.544621079119659], 
[31.070081646523697, 0.927782273878715, -1.9091320174841293, 0.15186481410364247]]

우리는 이로부터 회귀계수 각각의 표준편차를 구할 수 있습니다. SGD 기법으로 해를 구한 결과 $β$는 [30.55184071133236, 0.973395629801253, -1.8632386392995979, 0.9471533590851985]로 추정됐다는 사실로부터 $t_j$ 또한 구할 수 있습니다. 이로부터 SGD로 구한 $β$가 얼마나 믿을 만한지 추론해낼 수 있습니다.

정규화

Regulization은 회귀계수에 제약을 가해 일반화 성능을 높이는 기법입니다. 정규화와 관련해서는 이곳을 참고하시면 좋을 것 같습니다. 이 가운데 $β$의 L2 norm을 제한하는 릿지 회귀를 살펴보겠습니다. 파이썬 코드는 다음과 같습니다.

def ridge_penalty(beta, alpha):
    # alpha는 페널티의 강도를 조절하는 하이퍼 파라메터
    # ridge회귀는 상수에 대한 패널티는 주지 않는다
    return alpha * dot(beta[1:], beta[1:])

def squared_error_ridge(x_i, y_i, beta, alpha):
    # beta를 사용할 때 오류와 페널티의 합을 반환
    return error(beta, x_i, y_i) ** 2 + 
			ridge_penalty(beta, alpha)

def ridge_penalty_gradient(beta, alpha):
    # ridge회귀는 상수에 대한 패널티는 주지 않는다
    return [0] + [2 * alpha * beta_j for beta_j in beta[1:]]

def vector_add(v, w):
    """adds two vectors componentwise"""
    return [v_i + w_i for v_i, w_i in zip(v,w)]

def squared_error_ridge_gradient(x_i, y_i, beta, alpha):
    # i번째 오류 제곱값과 패널티 합의 기울기
    return vector_add(squared_error_gradient(x_i, y_i, beta),
                      ridge_penalty_gradient(beta, alpha))

def estimate_beta_ridge(x, y, alpha):
    # 페널티 파라메터가 alpha인 ridge 회귀를 SGD로 학습
    from functools import partial
    beta_initial = [random.random() for _ in x[0]]
    return minimize_stochastic(partial(squared_error_ridge, 
                                       alpha=alpha),
                               partial(squared_error_ridge_gradient, 
                                       alpha=alpha),
                               x, y,
                               beta_initial,
                               0.001)

alpha=0으로 실행하면 기존 선형회귀와 동일한 결과가 나옵니다. alpha를 증가시킬 수록 $β$가 작아집니다.

random.seed(0)
beta_0 = estimate_beta_ridge(input, output, alpha=0.0)
beta_10 = estimate_beta_ridge(input, output, alpha=10.0)

alpha=10으로 실행하면 $β$는 [28.30817306438882, 0.7309844552226781, -0.9146285684215772, -0.01693934899423252]으로 추정되는데요. 0에 가까운 네번째 회귀계수에 해당하는 변수는 ‘박사학위 취득 여부’입니다. 다시 말해 박사학위 취득 여부는 사이트 이용시간에 큰 영향을 주는 변수가 아니라는 걸 알 수 있습니다.

Comments