Modern Java in Action - Ch.7

참고: 책 - Modern Java in Action

책 Modern Java in Action을 읽고 정리합니다. 이번 포스트에서는 Ch 7.1 ~ Ch 7.4의 내용을 읽고 정리합니다.

Ch 7. 병렬 데이터 처리와 성능
7.1 병렬 스트림
- 7.1.1 순차 스트림을 병렬 스트림으로 변환하기
- 7.1.2 스트림 성능 측정
- 7.1.3 병렬 스트림의 올바른 사용법
- 7.1.4 병렬 스트림 효과적으로 사용하기
7.2 포크/조인 프레임워크
- 7.2.1 RecursiveTask 활용
- 7.2.2 포크/조인 프레임워크를 제대로 사용하는 방법
- 7.2.3 작업 훔치기
7.3 Splitator 인터페이스
- 7.3.1 분할 과정
- 7.3.2 커스텀 Spliterator 구현하기

Java 7이 등장하기 전… 데이터 컬렉션을 병렬로 처리하기가 어려웠습니다.

Java 7은… 더 쉽게 병렬화를 수행하면서 에러를 최소화 할 수 있도록 Fork/Join Framework 기능을 제공합니다. 7장에서는 스트림으로 데이터 컬렉션 관련 동작을 얼마나 쉽게 병렬로 실행할 수 있는지를 다룹니다.




7.1 병렬 스트림

컬렉션에 parallelStream을 호출하면 병렬 스트림이 생성됩니다. 병렬 스트림은 각각의 스레드에서 처리할 수 있도록 스트림 요소를 여러 청크로 분할한 스트림입니다. 따라서 병렬 스트림을 이용하면 모든 멀티코어 프로세서가 각각의 청크를 처리하도록 할당할 수 있습니다.

ex. 숫자 n을 인수로 받아서 1부터 n까지의 모든 숫자의 합계를 반환하는 메서드를 구현하는 예제입니다.

public long sequentialSum(long n) {
    return Stream.iterate(1L, i -> i + 1) //무한 자연수 스트림 생성
                  .limit(n) //주어진 크기로 스트림 제한
                  .reduce(0L, Long::sum); //모든 숫자를 더하는 스트림 리듀싱 연산
}
public long iterativeSum(long n) {
    long result = 0;
    for (long i = 1L; i <= n; i++) {
        result += i;
    }
    return result;
}




7.1.1 순차 스트림을 병렬 스트림으로 변환하기

public long parallelSum(long n) {
    return Stream.iterate(1L, i -> i + 1)
                 .limit(n)
                 .parallel() //스트림을 병렬 스트림으로 변환
                 .reduce(0L, Long::sum);
}

IMG_5966408BBC1A-1

stream.parallel()
        .filter(...)
        .sequential()
        .map(...)
        .parallel()
        .reduce();




7.1.2 스트림 성능 측정

<!--핵심 JMH 구현을 포함-->
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-generator-annprocess</artifactId>
    <version>1.35</version>
</dependency>
<!--JAR 파일을 만드는데 도움을 주는 어노테이션 프로세서를 포함-->
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-core</artifactId>
    <version>1.35</version>
</dependency>
<build>
    <plugins>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-shade-plugin</artifactId>
            <executions>
                <execution>
                    <phase>package</phase>
                    <goals><goal>shade</goal></goals>
                    <configuration>
                        <finalName>benchmarks</finalName>
                        <transformers>
                            <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                <mainClass>org.openjdk.jmh.Main</mainClass>
                            </transformer>
                        </transformers>
                    </configuration>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>


ex. n개의 숫자를 더하는 함수의 성능을 측정하는 예제입니다.

sequentialSum - 순차적 스트림 사용

@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime) //벤치마크 대상 메서드를 실행하는데 걸린 평균 시간 측정
@OutputTimeUnit(TimeUnit.MILLISECONDS) //벤치마크 결과를 밀리초 단위로 출력
@Fork(value = 2, jvmArgs = {"-Xms4G", "-Xms4G"}) //4Gb의 힙 공간을 제공한 환경에서 두 번 벤치마크를 수행해 결과의 신뢰성 확보
@Measurement(iterations = 20)
@Warmup(iterations = 3)
public class ParallelStreamBenchmark {
    private static final long N = 10_000_000L;

    @Benchmark
    public long seqeuntialSum() {
        return Stream.iterate(1L, i -> i + 1).limit(N)
                .reduce(0L, Long::sum);
    }

    @TearDown(Level.Invocation) //매 번 벤치마크를 실행한 다음에는 가비지 컬렉터 동작 시도
    public void tearDown() {
        System.gc();
    }
}

스크린샷 2022-08-17 오후 12 44 09



iterativeSum - 기본 for문 사용

@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 2, jvmArgs = { "-Xms4G", "-Xmx4G" })
@Measurement(iterations = 20)
@Warmup(iterations = 3)
public class ParallelStreamBenchmark {

  private static final long N = 10_000_000L;

  @Benchmark
  public long iterativeSum() {
    long result = 0;
    for (long i = 1L; i <= N; i++) {
      result += i;
    }
    return result;
  }

  @TearDown(Level.Invocation) //매 번 벤치마크를 실행한 다음에는 가비지 컬렉터 동작 시도
  public void tearDown() {
    System.gc();
  }
}

image



parallelSum - 병렬 스트림 사용

@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 2, jvmArgs = { "-Xms4G", "-Xmx4G" })
@Measurement(iterations = 20)
@Warmup(iterations = 3)
public class ParallelStreamBenchmark {

  @Benchmark
  public long parallelSum() {
    return Stream.iterate(1L, i -> i + 1).limit(N).parallel().reduce(0L, Long::sum);
  }

  @TearDown(Level.Invocation) //매 번 벤치마크를 실행한 다음에는 가비지 컬렉터 동작 시도
  public void tearDown() {
    System.gc();
  }
}

스크린샷 2022-08-17 오후 1 14 28



rangedSum

@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 2, jvmArgs = { "-Xms4G", "-Xmx4G" })
@Measurement(iterations = 20)
@Warmup(iterations = 3)
public class ParallelStreamBenchmark {

  private static final long N = 10_000_000L;
  
  @Benchmark
  public long rangedSum() {
    return LongStream.rangeClosed(1, N).reduce(0L, Long::sum);
  }
  
  @TearDown(Level.Invocation)
  public void tearDown() {
    System.gc();
  }

}

image



parallelRangedSum - 새로운 버전에 병렬 스트림 적용

@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 2, jvmArgs = { "-Xms4G", "-Xmx4G" })
@Measurement(iterations = 20)
@Warmup(iterations = 3)
public class ParallelStreamBenchmark {

  private static final long N = 10_000_000L;
  
  @Benchmark
  public long parallelRangedSum() {
    return LongStream.rangeClosed(1, N).parallel().reduce(0L, Long::sum);
  }
  
  @TearDown(Level.Invocation)
  public void tearDown() {
    System.gc();
  }

}

image




7.1.3 병렬 스트림의 올바른 사용법

스트림을 병렬화해서 코드 실행 속도를 빠르게 하고 싶으면 항상 병렬화를 올바르게 사용하고 있는지 확인해야 합니다. 병렬 스트림을 잘못 사용하면서 발생하는 많은 문제는 공유된 상태를 바꾸는 알고리즘을 사용하기 때문에 일어납니다.

ex. n까지의 자연수를 더하면서 공유된 누적자를 바꾸는 프로그램을 구현한 예제 입니다.

public class ParallelStreams {
    
  public static long sideEffectSum(long n) {
    Accumulator accumulator = new Accumulator();
    LongStream.rangeClosed(1, n).forEach(accumulator::add);
    return accumulator.total;
  }

  public static class Accumulator {

    private long total = 0;

    public void add(long value) {
      total += value;
    }

  }
}

public class ParallelStreams {

  public static long sideEffectParallelSum(long n) {
    Accumulator accumulator = new Accumulator();
    LongStream.rangeClosed(1, n).parallel().forEach(accumulator::add);
    return accumulator.total;
  }
}
public class ParallelStreamsHarness {
  public static void main(String[] args) {
    System.out.println("SideEffect parallel sum done in: " + measurePerf(ParallelStreams::sideEffectParallelSum, 10_000_000L) + " msecs" );
  }
}

image




7.1.4 병렬 스트림 효과적으로 사용하기

어떤 상황에서 병렬 스트림을 사용할 것인지 약간의 수량적 힌트를 정하는 것이 도움이 될 때도 있습니다. 아래 기준들을 통해 알아보도록 하겠습니다.

확신이 서지 않으면 직접 측정하라.

박싱을 주의하라.

순차 스트림보다 병렬 스트림에서 성능이 떨어지는 연산이 있다.

스트림에서 수행하는 전체 파이프라인 연산 비용을 고려하라.

소량의 데이터에서는 병렬 스트림이 도움 되지 않는다.

스트림을 구성하는 자료구조가 적절한지 확인하라.

스트림의 특성과 파이프라인의 중간 연산이 스트림의 특성을 어떻게 바꾸는지에 따라 분해 과정의 성능이 달라질 수 있다.

최종 연산의 병합 과정 비용을 살펴보라.

참고: 스트림 소스의 병렬화 친밀도(분해성)




7.2 포크/조인 프레임워크

포크/조인 프레임워크는 병렬화할 수 있는 작업을 재귀적으로 작은 작업으로 분할한 다음에 서브태스크 각각의 결과를 합쳐서 전체 결과를 만들도록 설계되었습니다. 포크/조인 프레임워크에서는 서브 태스크를 스레드 풀(ForkJoinPool)의 작업자 스레드에 분산 할당하는 ExecutorService 인터페이스를 구현합니다.




7.2.1 Recursive Task 활용

스레드 풀을 이용하려면 RecursiveTask<R>의 서브클래스를 만들어야 합니다. 여기서 R은 병렬화된 태스크가 생성하는 결과 형식 또는 결과가 없을 때는 RecursiveAction 형식입니다. RecursiveTask를 정의하려면 추상 메서드 compute를 구현해야 합니다.

protected abstract R compute();
if (태스크가 충분히 작거나 더 이상 분할할 수 없으면) {
    순차적으로 태스크 계산
} else {
    태스크를 두 서브태스크로 분할
    태스크가 다시 서브태스크로 분할되도록 이 메서드를 재귀적으로 호출함
    모든 서브태스크의 연산이 완료될 때까지 기다림
    각 서브태스크의 결과를 합침
}

재귀적인 태스크 분할 과정 (포크/조인 과정)

image

예제를 통해 포크/조인 프레임워크를 사용하는 방법을 확인해보기

ex. 포크/조인 프레임워크를 이용해서 범위의 숫자를 더하는 예제입니다.(long[]으로 이루어진 숫자 배열 사용)

public class ParallelStreamsHarness {
  public static final ForkJoinPool FORK_JOIN_POOL = new ForkJoinPool();
}
import static modernjavainaction.chap07.ParallelStreamsHarness.FORK_JOIN_POOL;

import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.LongStream;

//RecursiveTask를 상속받아 포크/조인 프레임워크에서 사용할 태스크를 생성
public class ForkJoinSumCalculator extends RecursiveTask<Long> {

  //이 값 이하의 서브태스크는 더 이상 분할할 수 없음
  public static final long THRESHOLD = 10_000;

  private final long[] numbers; //더할 숫자 배열
  private final int start; //이 서브태스크에서 처리할 배열의 초기 위치
  private final int end; //최종 위치

  //메인 태스크를 생성할 때 사용할 공개 생성자
  public ForkJoinSumCalculator(long[] numbers) {
    this(numbers, 0, numbers.length);
  }

  //메인 태스크의 서브태스크를 재귀적으로 만들 때 사용할 비공개 생성자
  private ForkJoinSumCalculator(long[] numbers, int start, int end) {
    this.numbers = numbers;
    this.start = start;
    this.end = end;
  }

  //RecursiveTask의 추상 메서드 오버라이드
  @Override
  protected Long compute() {
    //태스크에서 더할 배열의 길이
    int length = end - start;
    if (length <= THRESHOLD) {
      return computeSequentially(); //기준값과 같거나 작으면 순차적으로 결과를 계산
    }
    //배열의 첫 번째 절반을 더하도록 서브 태스크를 생성
    ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length / 2);
    leftTask.fork(); //ForkJoinPool의 다른 스레드로 새로 생성한 태스크를 비동기로 실행
    //배열의 나머지 절반을 더하도록 서브 태스크를 생성
    ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length / 2, end);
    Long rightResult = rightTask.compute(); //두 번째 서브태스크를 동기 실행(이때 추가로 분할이 일어날 수 있음)
    Long leftResult = leftTask.join(); //첫 번째 서브태스크의 결과를 읽거나 아직 결과가 없으면 기다림
    return leftResult + rightResult; //두 서브 태스크의 결과를 조합한 값이 이 태스크의 결과
  }

  //더 분할할 수 없을 때 서브태스크의 결과를 계산하는 단순한 알고리즘
  private long computeSequentially() {
    long sum = 0;
    for (int i = start; i < end; i++) {
      sum += numbers[i];
    }
    return sum;
  }

  /*
  위 코드는 n까지의 자연수 덧셈 작업을 병렬로 수행하는 방법을 더 직관적으로 보여줍니다.
  ForkJoinSumCalculator의 생성자로 원하는 수의 배열을 넘겨줄 수 있습니다.
  */

  public static long forkJoinSum(long n) {
    //LongStream으로 n까지의 자연수를 포함하는 배열을 생성
    long[] numbers = LongStream.rangeClosed(1, n).toArray();
    //생성된 배열을 ForkJoinSumCalculator의 생성자로 전달해서 ForkJoinTask를 생성
    ForkJoinTask<Long> task = new ForkJoinSumCalculator(numbers);
    //생성한 태스크를 새로운 ForkJoinPool의 invoke 메서드로 전달
    return FORK_JOIN_POOL.invoke(task);
    //ForkJoinPool에서 실행되는 마지막 invoke 메서드의 반환값은 ForkJoinSumCalculator에서 정의한 태스크의 결과가 됩니다.
  }

}

ForkJoinSumCalculator 실행

IMG_C82C60BF1E58-1




7.2.2 포크/조인 프레임워크를 제대로 사용하는 방법




7.2.3 작업 훔치기

이론적으로는 코어 개수만큼 병렬화된 태스크로 작업부하를 분할하면 모든 CPU 코어에서 태스크를 실행할 것이고 크기가 같은 각각의 태스크는 같은 시간에 종료될 것이라고 생각할 수 있습니다. 그러나 현실에서는 각각의 서브태스크의 작업완료 시간이 크게 달라질 수 있습니다. 분할 기법이 효율적이지 않았기 때문일 수도 있고 아니면 예기치 않게 디스크 접근 속도가 저하되었거나 외부 서비스와 협력하는 과정에서 지연이 생길 수 있기 때문입니다.


포크/조인 프레임워크에서는 작업 훔치기라는 기법으로 이 문제를 해결합니다.

image




7.3 Spliterator 인터페이스

Spliterator는 분할할 수 있는 반복자라는 의미로 Iterator처럼 소스의 요소 탐색 기능을 제공한다는 점은 같지만 병렬 작업에 특화되어 있습니다. Java 8은 컬렉션 프레임워크에 포함된 모든 자료구조에 사용할 수 있는 디폴트 Spliterator 구현을 제공합니다. 컬렉션은 spliterator라는 메서드를 제공하는 Splitator 인터페이스를 구현합니다.

Spliterator 인터페이스

public interface Spliterator<T> { // T: 탐색하는 요소의 형식

    // tryAdvance: Spliterator의 요소를 하나씩 순차적으로 소비하면서 탐색해야 할 요소가 남아있으면 참을 반환 (Iterator 처럼)
    boolean tryAdvance(Consumer<? super T> action);
    
    // trySplit: Spliterator의 일부 요소(자신이 반환한 요소)를 분할해서 두 번째 Spliterator를 생성하는 메서드
    Spliterator<T> trySplit();
    
    // estimateSize: 탐색해야 할 요소 수 정보를 제공
    long estimateSize();
    
    int characteristics();
}




7.3.1 분할 과정

IMG_86D1C0BC329A-1


IMG_ED12E934DFB0-1

Spliterator 특성

int characteristics();




7.3.2 커스텀 Spliterator 구현하기

ex. 문자열의 단어 수를 계산하는 메서드를 구현하는 예제입니다.

import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class WordCount {

  public static final String SENTENCE =
      " Nel   mezzo del cammin  di nostra  vita "
      + "mi  ritrovai in una  selva oscura"
      + " che la  dritta via era   smarrita ";

  public static void main(String[] args) {
    System.out.println("Found " + countWordsIteratively(SENTENCE) + " words");
  }

  public static int countWordsIteratively(String s) {
    int counter = 0;
    boolean lastSpace = true;
    //문자열의 모든 문자를 하나씩 탐색
    for (char c : s.toCharArray()) {
      if (Character.isWhitespace(c)) {
        lastSpace = true;
      }
      else {
        //문자를 하나씩 탐색하다 공백 문자를 만나면 지금까지 탐색한 문자로 간주하여(공백 문자는 제외) 단어 수를 증가
        if (lastSpace) {
          counter++;
        }
        lastSpace = Character.isWhitespace(c);
      }
    }
    return counter;
  }
}

결과

image



함수형으로 단어 수를 세는 메서드 재구현하기

Stream<Character> stream = IntStream.range(0, SENTENCE.length()).mapToObj(SENTENCE::charAt);
private static class WordCounter {

  private final int counter;
  private final boolean lastSpace;

  public WordCounter(int counter, boolean lastSpace) {
    this.counter = counter;
    this.lastSpace = lastSpace;
  }

  //문자열의 문자를 하나씩 탐색
  public WordCounter accumulate(Character c) {
    if (Character.isWhitespace(c)) {
      return lastSpace ? this : new WordCounter(counter, true);
    } else {
        //문자를 하나씩 탐색하다가 공백 문자를 만나면 지금까지 탐색한 문자를 단어로 간주, 단어 수 증가
      return lastSpace ? new WordCounter(counter + 1, false) : this;
    }
  }

  //두 WordCounter의 counter값을 더함
  public WordCounter combine(WordCounter wordCounter) {
    return new WordCounter(counter + wordCounter.counter, wordCounter.lastSpace); // 마지막 공백은 신경 안씀
  }

  public int getCounter() {
    return counter;
  }

}

private static int countWords(Stream<Character> stream) {
    WordCounter wordCounter = stream.reduce(new WordCounter(0, true), WordCounter::accumulate, WordCounter::combine);
    return wordCounter.getCounter();
}
public static void main(String[] args) {
    System.out.println("Found " + countWords(SENTENCE) + " words");
}

결과

image



WordCounter 병렬로 수행하기

public static void main(String[] args) {
    System.out.println("Found " + countWords(SENTENCE) + " words");
}

public static int countWords(String s) {
    Stream<Character> stream = IntStream.range(0, s.length()).mapToObj(SENTENCE::charAt).parallel();
    Spliterator<Character> spliterator = new WordCounterSpliterator(s);
    return countWords(stream);
}

결과

image

private static class WordCounterSpliterator implements Spliterator<Character> {

    private final String string;
    private int currentChar = 0;

    private WordCounterSpliterator(String string) {
        this.string = string;
    }

    @Override
    public boolean tryAdvance(Consumer<? super Character> action) {
        //현재 문자를 소비
        action.accept(string.charAt(currentChar++));
        //소비할 문자가 남아있으면 true 반환
        return currentChar < string.length();
    }

    @Override
    public Spliterator<Character> trySplit() {
        int currentSize = string.length() - currentChar;
        //파싱할 문자열을 순차 처리할 수 있을만큼 충분히 작아졌음을 알리는 null 반환
        if (currentSize < 10) {
            return null;
        }
        //파싱할 문자열의 중간을 분할 위치로 설정
        for (int splitPos = currentSize / 2 + currentChar; splitPos < string.length(); splitPos++) {
            //다음 공백이 나올 때까지 분할 위치를 뒤로 이동시킴
            if (Character.isWhitespace(string.charAt(splitPos))) {
                //처음부터 분할 위치까지 문자열을 파싱할 새로운 WordCounterSpliterator를 생성
                Spliterator<Character> spliterator = new WordCounterSpliterator(string.substring(currentChar, splitPos));
                //이 WordCounterSpliterator의 시작 위치를 분할 위치로 설정
                currentChar = splitPos;
                //공백을 찾았고 문자열을 분리했으므로 루프를 종료
                return spliterator;
            }
        }
        return null;
    }

    //탐색해야 할 요소의 개수
    @Override
    public long estimateSize() {
        //파싱할 문자열 전체 길이 - 현재 반복중인 위치
        return string.length() - currentChar;
    }

    @Override
    public int characteristics() {
        /*
          ORDERED: 문자열의 문자 등장 순서가 유의미함
          SIZED: estimatedSize 메서드의 반환값이 정확함
          SUBSIZED: trySplit으로 생성된 Spliterator도 정확한 크기를 가짐
          NONNULL: 문자열에는 null 문자가 존재하지 않음
          IMMUTABLE: 문자열 자체가 불변 클래스이므로 문자열을 파싱하면서 속성이 추가되지 않음
        */
        return ORDERED + SIZED + SUBSIZED + NONNULL + IMMUTABLE;
    }

}

결과 - WordCounterSpliterator 활용

public static void main(String[] args) {
    System.out.println("Found " + countWords(SENTENCE) + " words");
}

public static int countWords(String s) {
    Spliterator<Character> spliterator = new WordCounterSpliterator(s);
    //true는 병렬 스트림 생성 여부를 지시
    Stream<Character> stream = StreamSupport.stream(spliterator, true);

    return countWords(stream);
}

image