요번에 제가 공부한거는 영상 화질 개선 쪽 관심을 갖게되어
한번 공부 하고 다른곳에 접목 시켜봐야겟다라는 취지로 공부하게 되엇고
이 공부를 하면서 처음으로 난관이엇던게 다른 딥러닝 과정과 다르게 이미지 처리 과정인 전처리 과정에서 꽤 애먹엇습니다
왜냐면 patch 사이즈가 잇는데 그것에 조건이 맞아야 되고 그 patch가 과연 어떤 저화질인가 원본이가 등등 생각을 하면서
알게된게 원본사진을 저화질 사진을 만들고 그리고 이미지를 patch 사이즈만큼 크롭 합니다 그리고 그 크롭된 이미지 바탕으로 학습시키는 것엿습니다
전 optimizer를 SGD 모멘트를 사용하엿고 단순한 SRCNN모델 구조 9-1-5 과정을 하엿습니다
변환 >>>
코드를 보시면
필요한 패키지들 import
import os
import numpy as np
from scipy import misc
import glob
import scipy.misc
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.ndimage
import cv2
학습데이터셋 정리
path = 'Train/'
dir_path = os.path.join(os.getcwd(), path) # 현재위치 주소를 가져온다 절대경로로
image_paths = glob.glob(os.path.join(dir_path, '*.bmp')) # glob 메소드를 활용해 bmp 파일을 전부 가져온다
patch 정보
I= 33
L= 21
stride = 21
scale =3
이미지 정보 사진으로 출력해보기
데이터 전처리 과정
inputs = []
labels = []
for path in image_paths:
image =scipy.misc.imread(path).astype(np.int)
h,w,c=image.shape
h = h - np.mod(h,3)
w = w - np.mod(w,3)
image = image[:h, :w]
label = image/255.0
inp=cv2.GaussianBlur(label,(15,15),0)
sub_inputs = []
sub_labels = []
h, w = inp.shape[0], inp.shape[1]
offset = abs(I - L)//2
for hh in range(0, h-I+1, stride):
for ww in range(0, w-I+1, stride):
sub_input = inp[hh:hh+I, ww:ww+I]
sub_label = label[hh+offset:hh+offset+L, ww+offset:ww+offset+L]
sub_input = sub_input.reshape(I, I, 3)
sub_label = sub_label.reshape(L, L, 3)
sub_inputs.append(sub_input)
sub_labels.append(sub_label)
inputs += sub_inputs
labels += sub_labels
inputs = np.asarray(inputs) # shape (N, I, I, 1)
labels = np.asarray(labels) # shape (N, L, L, 1)
import torch
from torch.utils.data import Dataset, DataLoader
torch 데이터셋을 만들기위한 과정
class SRdataset(Dataset):
def __init__(self):
self.inputs,self.labels = inputs,labels
def __len__(self):
return self.inputs.shape[0]
def __getitem__(self,idx):
input_sample = self.inputs[idx]
label_sample = self.labels[idx]
input_sample = input_sample.transpose(2,0,1)
label_sample = label_sample.transpose(2,0,1)
input_sample, label_sample = torch.Tensor(input_sample), torch.Tensor(label_sample)
return input_sample,label_sample
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
# 모델 정의
class SRCNN(nn.Module):
def __init__(self):
super(SRCNN,self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3,64,kernel_size=9),
nn.ReLU(),
)
self.layer2 = nn.Sequential(
nn.Conv2d(64,32,kernel_size=1),
nn.ReLU(),
)
self.layer3 = nn.Sequential(
nn.Conv2d(32,3,kernel_size=5),
)
def forward(self,x):
x = self.layer1(x)
# print('1',x.shape)
x = self.layer2(x)
# print('2',x.shape)
x = self.layer3(x)
# print('3',x.shape)
return x
SRCNN( (layer1): Sequential( (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1)) (1): ReLU() ) (layer2): Sequential( (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) (1): ReLU() ) (layer3): Sequential( (0): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1)) ) )
손실함수 optimizer 등 정리
crition = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=1e-2,momentum=(0.9))
train_set = SRdataset()
train_loader = DataLoader(train_set,batch_size=64,shuffle=False)
Train 부분 학습하기
for epoch in range(100):
for i in train_loader:
model.zero_grad()
x, y = i
image_data = Variable(x)
label_data = Variable(y)
output_data = model(image_data)
loss = crition(output_data,label_data)
loss.backward()
optimizer.step()
print(epoch,loss.mean())
torch.save(model,'rgb_model/SRCNN_model%s_%s.pkl'%(epoch+100,float(loss.mean())))
91 tensor(0.0007, grad_fn=<MeanBackward1>) 92 tensor(0.0007, grad_fn=<MeanBackward1>) 93 tensor(0.0007, grad_fn=<MeanBackward1>) 94 tensor(0.0007, grad_fn=<MeanBackward1>) 95 tensor(0.0007, grad_fn=<MeanBackward1>) 96 tensor(0.0007, grad_fn=<MeanBackward1>) 97 tensor(0.0007, grad_fn=<MeanBackward1>) 98 tensor(0.0007, grad_fn=<MeanBackward1>) 99 tensor(0.0007, grad_fn=<MeanBackward1>)
아무 사진이나 test 해보기
plt.imshow(label, cmap='gray')
plt.xticks([]),plt.yticks([])
plt.savefig('label.jpg')
plt.pause(0.005)
plt.imshow(inp, cmap='gray')
plt.xticks([]),plt.yticks([])
plt.savefig('inp.jpg')
plt.pause(0.005)
plt.imshow(test_output, cmap='gray')
plt.xticks([]),plt.yticks([])
plt.savefig('output.jpg')
원본 사진
원본사진을 저화질로 강제 변환
강제로 저화질로 된 사진을 고화질로 변환
'이미지 처리' 카테고리의 다른 글
대학생 1학년이 이해하는 Opencv 와 deel learning을 이용한 영상속에 나를 찾아서 주는 모델 (0) | 2019.01.15 |
---|---|
대학생 1학년이 이해하는 kaggle Art Images 이미지 분류 (1) | 2019.01.08 |