[Tensorflow] Tensorflow.js

2021. 2. 7. 22:52·딥러닝
반응형
tfjs

자바스크립트는 웹 브라우저에서 주로 사용하는 스크립트 언어이다. 웹 페이지를 동적으로 구성하기 위해 사용된다. 웹 브라우저에서 딥러닝 모델을 사용할 수 있도록 텐서플로우의 자바스크립트 버전이 제공된다.

 

CDN

CDN을 사용하면 서버 입장에서는 제공하려는 컨텐츠를 직접 가지고 있을 필요 없이 사용자가 제공자에게 직접 다운로드 할 수 있다. tfjs의 cdn은 아래와 같이 사용할 수 있다.

xxxxxxxxxx
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

tfjs뿐만 아니라 tfjs-vis를 포함해 학습 과정 및 결과를 웹페이지에 표시하려한다.

 

Dataset

간단히 학습할 수 있는 MNIST 데이터셋을 사용해본다. 0~9까지의 손글씨로 쓴 숫자들로 5500개의 학습 데이터와 1000개의 테스트 데이터를 가진다. 텐서플로 github에서 js 파일로 제공되는 파일을 사용한다. 이 파일을 직접 다운로드해서 data.js파일을 생성한다.

minst

Code

파일 구성은 아래와 같다.

x
DATA_ROOT
├── js
│   ├── data.js
│   └── main.js
└── index.html

data.js는 다운받은 mnist 데이터를 나타낸다. index.html과 main.js는 코드를 작성한다.

기본 웹페이지인 index.html은 아래와 같이 구성한다.

x
<!DOCTYPE html>
<html lang="ko">
  <head>
    <meta charset="utf-8">
  </head>
  <body>
  </body>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>
    <script type="module" src="./js/data.js"></script>
    <script type="module" src="./js/main.js"></script>
</html>

별 다른 내용 없이 자바스크립트 코드를 불러오는 역할만 한다.

main.js의 전체 코드는 여기에서 확인할 수 있다.

 

모델 생성

모델 생성법은 파이썬에서 사용하던 방법과 동일하다.

xxxxxxxxxx
const model = tf.sequential();
model.add(
    tf.layers.conv2d(
        {
            inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
            kernelSize: 3,
            filters: 32,
            strides: 1,
            activation: "relu",
            kernelInitializer: "heUniform"
        }
    )
);
model.add(tf.layers.flatten());
model.add(
    tf.layers.dense(
        {
            units: NUM_OUTPUT_CLASSES,
            kernelInitializer: "heUniform",
            activation: "softmax"
        }
    )
);
​
const optimizer = tf.train.adam(0.001);
​
model.compile({
    optimizer: optimizer,
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"]
});

공식 문서에서 api를 확인할 수 있다.

 

tfvis

tf-vis는 학습에 대한 내용을 웹 페이지에 표시할 수 있는 라이브러리이다. 공식문서에서 클래스나 함수의 사용에 대한 상세한 내용을 확인할 수 있다.

 

Summary

xxxxxxxxxx
const model = getModel();
tfvis.show.modelSummary(
    {
        name: "Model Architecture"
    },
    model
);

01

텐서플로와 동일하게 모델 구조를 출력할 수 있다.

 

Callback

x
const metrics = ["loss", "val_loss", "acc", "val_acc"];
const container = {
    name: "model Training",
    styles:
    {
        height:'10000px'
    }
};
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

02

학습이 진행됨에 따라 loss와 accuracy의 변화하는 모습을 그래프로 나타낸다. 이 콜백 함수를 model.copile의 인자로 넘겨준다.

 

Accuracy

x
const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
const container = {
    name: "Accuracy",
    tab: "Evaluation"
};
tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

03

모델의 예측결과를 배열 형태로 전달하고, 각 클래스에 대한 샘플 수와 정확도를 표시한다.

 

Confusion matrix

x
const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
const container = {
    name: "Confusion Matrix",
    tab: "Evaluation"
};
tfvis.render.confusionMatrix(
    container,
    {
        values: confusionMatrix
    },
    classNames
);

05

예측 졀과에 대한 confusion matrix를 출력한다. 이를 사용하면 어떤 클래스를 잘못 예측하는지 파악하기 용이하다.

 

 

 

 

반응형
저작자표시 비영리 변경금지 (새창열림)
'딥러닝' 카테고리의 다른 글
  • [딥러닝] MNIST
  • [CUDA] 설치
덴마크초코우유
덴마크초코우유
IT, 알고리즘, 프로그래밍 언어, 자료구조 등 정리
    반응형
  • 덴마크초코우유
    이것저것끄적
    덴마크초코우유
  • 전체
    오늘
    어제
    • 분류 전체보기 (120)
      • Spring Framework (7)
        • Spring (3)
        • JPA (2)
        • Spring Security (0)
      • Language (51)
        • Java (11)
        • Python (10)
        • JavaScript (5)
        • NUXT (2)
        • C C++ (15)
        • PHP (8)
      • DB (16)
        • MySQL (10)
        • Reids (3)
        • Memcached (2)
      • 개발 (3)
      • 프로젝트 (2)
      • Book (2)
      • PS (15)
        • 기타 (2)
        • 백준 (2)
        • 프로그래머스 (10)
      • 딥러닝 (8)
        • CUDA (0)
        • Pytorch (0)
        • 모델 (0)
        • 컴퓨터 비전 (4)
        • OpenCV (1)
      • 기타 (16)
        • 디자인패턴 (2)
        • UnrealEngine (8)
        • ubuntu (1)
        • node.js (1)
        • 블로그 (1)
  • 블로그 메뉴

    • 홈
    • 태그
    • 미디어로그
    • 위치로그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    알고리즘
    select
    MySQL
    클래스
    mscoco
    프로그래머스
    Python
    딥러닝
    Unreal
    memcached
    JavaScript
    C
    php
    pytorch
    PS
    redis
    자바
    C++
    파이썬
    블루프린트
    웹
    map
    NUXT
    CPP
    게임 개발
    게임
    FPS
    JS
    Unreal Engine
    언리얼엔진4
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
덴마크초코우유
[Tensorflow] Tensorflow.js
상단으로

티스토리툴바