/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kafka.jmh.consumer;

import org.apache.kafka.clients.Metadata;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.clients.consumer.internals.SubscriptionState;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.LogContext;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

/**
 * JMH benchmark for read-only queries against a {@link SubscriptionState} holding a large,
 * manually assigned set of partitions.
 *
 * <p>During trial setup every partition is assigned via {@code assignFromUser}, positioned at
 * offset 0 via {@code seekUnvalidated}, and then marked validated via {@code completeValidation}.
 * Each benchmark method returns its result so JMH consumes it and the call cannot be
 * eliminated as dead code.
 */
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 5)
@Measurement(iterations = 15)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public class SubscriptionStateBenchmark {
    /** Number of distinct topics in the assignment. */
    @Param({"5000"})
    int topicCount;

    /** Number of partitions generated per topic. */
    @Param({"50"})
    int partitionCount;

    /** State under test; rebuilt once per trial by {@link #setup()}. */
    SubscriptionState subscriptionState;

    /**
     * Builds the subscription state: assigns all generated partitions, seeks each one to a
     * fixed position and completes its validation.
     */
    @Setup(Level.Trial)
    public void setup() {
        final Set<TopicPartition> partitions = generateAssignment();

        subscriptionState = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
        subscriptionState.assignFromUser(partitions);

        // A single shared position: offset 0, epoch 0, leader node 0 at epoch 10.
        final Node leader = new Node(0, "host", 9092);
        final Metadata.LeaderAndEpoch leaderAndEpoch =
            new Metadata.LeaderAndEpoch(Optional.of(leader), Optional.of(10));
        final SubscriptionState.FetchPosition startPosition =
            new SubscriptionState.FetchPosition(0L, Optional.of(0), leaderAndEpoch);

        for (TopicPartition partition : partitions) {
            subscriptionState.seekUnvalidated(partition, startPosition);
            subscriptionState.completeValidation(partition);
        }
    }

    /**
     * Creates {@code topicCount * partitionCount} partitions named {@code topic-NNNN},
     * inserted topic by topic in ascending partition order.
     */
    private Set<TopicPartition> generateAssignment() {
        final Set<TopicPartition> partitions = new HashSet<>(topicCount * partitionCount);
        IntStream.range(0, topicCount)
            .mapToObj(topicId -> String.format("topic-%04d", topicId))
            .forEach(topic -> {
                for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
                    partitions.add(new TopicPartition(topic, partitionId));
                }
            });
        return partitions;
    }

    /** Measures {@code hasAllFetchPositions()} across the whole assignment. */
    @Benchmark
    public boolean testHasAllFetchPositions() {
        final boolean allPositioned = subscriptionState.hasAllFetchPositions();
        return allPositioned;
    }

    /** Measures {@code fetchablePartitions} with a predicate that accepts every partition. */
    @Benchmark
    public int testFetchablePartitions() {
        final int fetchableCount = subscriptionState.fetchablePartitions(partition -> true).size();
        return fetchableCount;
    }

    /** Measures {@code partitionsNeedingValidation} evaluated at time 0. */
    @Benchmark
    public int testPartitionsNeedingValidation() {
        final int needingValidation = subscriptionState.partitionsNeedingValidation(0L).size();
        return needingValidation;
    }
}
