/*
 * Copyright 2016 The Apache Software Foundation.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.storm.kafka.spout;

import static org.apache.storm.kafka.spout.config.builder.SingleTopicKafkaSpoutConfiguration.createKafkaSpoutConfigBuilder;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasKey;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.consumer.*;
import org.apache.kafka.common.TopicPartition;
import org.apache.storm.kafka.spout.config.builder.SingleTopicKafkaSpoutConfiguration;
import org.apache.storm.kafka.spout.internal.ClientFactory;
import org.apache.storm.kafka.spout.subscription.ManualPartitioner;
import org.apache.storm.kafka.spout.subscription.TopicAssigner;
import org.apache.storm.kafka.spout.subscription.TopicFilter;
import org.apache.storm.spout.SpoutOutputCollector;
import org.apache.storm.task.TopologyContext;
import org.apache.storm.utils.Time;
import org.apache.storm.utils.Time.SimulatedTime;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;

/**
 * Exercises {@link KafkaSpout} across partition rebalances. The tests drive the
 * {@link ConsumerRebalanceListener} that the spout hands to {@link TopicAssigner#assignPartitions}
 * and then check how acks, fails and seeks are handled for revoked, retained and newly assigned partitions.
 * The consumer, admin client, topic filter and partitioner are all Mockito mocks.
 */
@ExtendWith(MockitoExtension.class)
public class KafkaSpoutRebalanceTest {

    private final long offsetCommitPeriodMs = 2_000;
    private final Map<String, Object> conf = new HashMap<>();

    @Captor
    private ArgumentCaptor<Map<TopicPartition, OffsetAndMetadata>> commitCapture;
    @Mock
    private TopologyContext contextMock;
    @Mock
    private SpoutOutputCollector collectorMock;
    @Mock
    private KafkaConsumer<String, String> consumerMock;
    @Mock
    private Admin adminMock;
    @Mock
    private TopicFilter topicFilterMock;
    @Mock
    private ManualPartitioner partitionerMock;
    private ClientFactory<String, String> clientFactory;

    /**
     * Builds a client factory that serves the mocked consumer and admin client, and stubs the
     * topic filter and partitioner so each call yields a new, empty set.
     */
    @BeforeEach
    public void setUp() {
        clientFactory = new ClientFactory<String, String>() {
            @Override
            public Admin createAdmin(Map<String, Object> adminProps) {
                return adminMock;
            }

            @Override
            public Consumer<String, String> createConsumer(Map<String, Object> consumerProps) {
                return consumerMock;
            }
        };
        doAnswer(invocation -> new HashSet<>()).when(topicFilterMock).getAllSubscribedPartitions(any());
        doAnswer(invocation -> new HashSet<>()).when(partitionerMock).getPartitionsForThisTask(any(), any());
    }

    /**
     * Verifies that the given assigner was asked to assign partitions exactly once and returns the
     * rebalance listener that was passed in that call.
     */
    private ConsumerRebalanceListener captureRebalanceListener(TopicAssigner topicAssigner) {
        ArgumentCaptor<ConsumerRebalanceListener> listenerCaptor = ArgumentCaptor.forClass(ConsumerRebalanceListener.class);
        verify(topicAssigner).assignPartitions(any(), any(), listenerCaptor.capture());
        return listenerCaptor.getValue();
    }

    /**
     * Opens and activates the spout, assigns it both partitions, emits one tuple from each of them and then
     * rebalances so that only {@code assignedPartition} stays assigned.
     *
     * @return the message ids of the two emitted tuples, in order of emission
     *     (the revoked partition's id first, then the still-assigned partition's id)
     */
    private List<KafkaSpoutMessageId> emitOneMessagePerPartitionThenRevokeOnePartition(
        KafkaSpout<String, String> spout,
        TopicPartition partitionThatWillBeRevoked,
        TopicPartition assignedPartition,
        TopicAssigner topicAssigner) {
        //Start the spout on the mock consumer; this is what makes the rebalance listener available
        spout.open(conf, contextMock, collectorMock);
        spout.activate();

        //Hand both partitions to the spout through the captured listener
        ConsumerRebalanceListener rebalanceListener = captureRebalanceListener(topicAssigner);
        Set<TopicPartition> initialAssignment = new HashSet<>();
        Collections.addAll(initialAssignment, partitionThatWillBeRevoked, assignedPartition);
        rebalanceListener.onPartitionsAssigned(initialAssignment);
        when(consumerMock.assignment()).thenReturn(initialAssignment);

        //Successive polls yield one record for the soon-to-be-revoked partition, one for the other, then nothing
        ConsumerRecords<String, String> revokedPartitionBatch = new ConsumerRecords<>(
            Collections.singletonMap(partitionThatWillBeRevoked, SpoutWithMockedConsumerSetupHelper.createRecords(partitionThatWillBeRevoked, 0, 1)));
        ConsumerRecords<String, String> assignedPartitionBatch = new ConsumerRecords<>(
            Collections.singletonMap(assignedPartition, SpoutWithMockedConsumerSetupHelper.createRecords(assignedPartition, 0, 1)));
        ConsumerRecords<String, String> emptyBatch = new ConsumerRecords<>(Collections.emptyMap());
        when(consumerMock.poll(any(Duration.class)))
            .thenReturn(revokedPartitionBatch)
            .thenReturn(assignedPartitionBatch)
            .thenReturn(emptyBatch);

        //Emit both tuples, resetting the collector in between so each verify sees exactly one emit
        ArgumentCaptor<KafkaSpoutMessageId> revokedIdCaptor = ArgumentCaptor.forClass(KafkaSpoutMessageId.class);
        ArgumentCaptor<KafkaSpoutMessageId> assignedIdCaptor = ArgumentCaptor.forClass(KafkaSpoutMessageId.class);
        spout.nextTuple();
        verify(collectorMock).emit(anyString(), anyList(), revokedIdCaptor.capture());
        reset(collectorMock);
        spout.nextTuple();
        verify(collectorMock).emit(anyString(), anyList(), assignedIdCaptor.capture());

        //Rebalance: everything is revoked, then only assignedPartition comes back
        rebalanceListener.onPartitionsRevoked(initialAssignment);
        rebalanceListener.onPartitionsAssigned(Collections.singleton(assignedPartition));
        lenient().doAnswer(invocation -> Collections.singleton(assignedPartition)).when(consumerMock).assignment();

        List<KafkaSpoutMessageId> idsInEmissionOrder = new ArrayList<>();
        Collections.addAll(idsInEmissionOrder, revokedIdCaptor.getValue(), assignedIdCaptor.getValue());
        return idsInEmissionOrder;
    }

    @Test
    public void spoutMustIgnoreAcksForTuplesItIsNotAssignedAfterRebalance() {
        //Acks for partitions the spout no longer owns are pointless, because it is not allowed to commit them
        try (SimulatedTime ignored = new SimulatedTime()) {
            TopicPartition revokedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 1);
            TopicPartition retainedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 2);
            TopicAssigner assignerMock = mock(TopicAssigner.class);
            KafkaSpoutConfig<String, String> spoutConfig = createKafkaSpoutConfigBuilder(topicFilterMock, partitionerMock, -1)
                .setOffsetCommitPeriodMs(offsetCommitPeriodMs)
                .build();
            KafkaSpout<String, String> spout = new KafkaSpout<>(spoutConfig, clientFactory, assignerMock);

            //One tuple per partition is emitted, after which the first partition is taken away
            List<KafkaSpoutMessageId> messageIds = emitOneMessagePerPartitionThenRevokeOnePartition(
                spout, revokedPartition, retainedPartition, assignerMock);

            //Ack the tuple from the revoked partition and the one from the retained partition
            for (KafkaSpoutMessageId messageId : messageIds) {
                spout.ack(messageId);
            }

            //Let the commit timer run out, then give the spout a chance to commit
            Time.advanceTime(offsetCommitPeriodMs + KafkaSpout.TIMER_DELAY_MS);
            spout.nextTuple();

            //Exactly one commit is expected, and it must cover the retained partition only
            verify(consumerMock).commitSync(commitCapture.capture());
            Map<TopicPartition, OffsetAndMetadata> committedOffsets = commitCapture.getValue();
            assertThat(committedOffsets, hasKey(retainedPartition));
            assertThat(committedOffsets, not(hasKey(revokedPartition)));
        }
    }

    @Test
    public void spoutMustIgnoreFailsForTuplesItIsNotAssignedAfterRebalance() {
        //Retrying tuples from partitions the spout no longer owns is pointless, it could not commit them even if they succeed later
        TopicPartition revokedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 1);
        TopicPartition retainedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 2);
        TopicAssigner assignerMock = mock(TopicAssigner.class);
        KafkaSpoutRetryService retryServiceMock = mock(KafkaSpoutRetryService.class);
        KafkaSpoutConfig<String, String> spoutConfig = createKafkaSpoutConfigBuilder(topicFilterMock, partitionerMock, -1)
            .setOffsetCommitPeriodMs(10)
            .setRetry(retryServiceMock)
            .build();
        KafkaSpout<String, String> spout = new KafkaSpout<>(spoutConfig, clientFactory, assignerMock);

        //The retry service supplies the message ids: first one on the revoked partition, then one on the retained partition
        when(retryServiceMock.getMessageId(any(TopicPartition.class), anyLong()))
            .thenReturn(new KafkaSpoutMessageId(revokedPartition, 0), new KafkaSpoutMessageId(retainedPartition, 0));

        //One tuple per partition is emitted, after which the first partition is taken away
        List<KafkaSpoutMessageId> messageIds = emitOneMessagePerPartitionThenRevokeOnePartition(
            spout, revokedPartition, retainedPartition, assignerMock);
        KafkaSpoutMessageId revokedPartitionId = messageIds.get(0);
        KafkaSpoutMessageId retainedPartitionId = messageIds.get(1);

        //No more than the two expected message ids may have been requested
        verify(retryServiceMock, times(2)).getMessageId(any(TopicPartition.class), anyLong());

        //Fail the tuple from the revoked partition, then the one from the retained partition
        spout.fail(revokedPartitionId);
        spout.fail(retainedPartitionId);

        //A retry must be scheduled for the retained partition's tuple and for nothing else
        verify(retryServiceMock, never()).schedule(revokedPartitionId);
        verify(retryServiceMock).schedule(retainedPartitionId);
    }

    @Test
    public void testReassignPartitionSeeksForOnlyNewPartitions() {
        /*
         * On reassignment the spout should apply the first poll offset strategy to partitions that are new to it.
         * Partitions it already held should not be touched, because the spout keeps emitted and acked state for them.
         */
        TopicPartition retainedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 1);
        TopicPartition addedPartition = new TopicPartition(SingleTopicKafkaSpoutConfiguration.TOPIC, 2);
        TopicAssigner assignerMock = mock(TopicAssigner.class);
        KafkaSpoutConfig<String, String> spoutConfig = createKafkaSpoutConfigBuilder(topicFilterMock, partitionerMock, -1)
            .setFirstPollOffsetStrategy(FirstPollOffsetStrategy.UNCOMMITTED_EARLIEST)
            .build();
        KafkaSpout<String, String> spout = new KafkaSpout<>(spoutConfig, clientFactory, assignerMock);

        //Start the spout on the mock consumer; this is what makes the rebalance listener available
        spout.open(conf, contextMock, collectorMock);
        spout.activate();
        ConsumerRebalanceListener rebalanceListener = captureRebalanceListener(assignerMock);

        //Initial assignment holds a single partition; wipe the consumer mock afterwards so later verifications start clean
        Set<TopicPartition> initialAssignment = new HashSet<>();
        initialAssignment.add(retainedPartition);
        rebalanceListener.onPartitionsAssigned(initialAssignment);
        reset(consumerMock);

        //Pretend both partitions already have a committed offset; the lookup for the retained partition is allowed to go unused
        long committedOffset = 500;
        lenient().doAnswer(invocation -> new OffsetAndMetadata(committedOffset)).when(consumerMock).committed(retainedPartition);
        doAnswer(invocation -> new OffsetAndMetadata(committedOffset)).when(consumerMock).committed(addedPartition);

        //Rebalance: revoke everything, then assign the old partition together with a new one
        rebalanceListener.onPartitionsRevoked(initialAssignment);
        Set<TopicPartition> expandedAssignment = new HashSet<>();
        Collections.addAll(expandedAssignment, retainedPartition, addedPartition);
        rebalanceListener.onPartitionsAssigned(expandedAssignment);

        //The retained partition must not be repositioned
        verify(consumerMock, never()).seek(eq(retainedPartition), anyLong());
        //The added partition must be positioned at its committed offset
        verify(consumerMock).seek(addedPartition, committedOffset);
    }
}
