/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.indices.replication;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CancellableThreads;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.index.shard.IndexEventListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.recovery.FileChunkRequest;
import org.opensearch.indices.recovery.ForceSyncRequest;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationState;
import org.opensearch.indices.replication.SegmentReplicationTarget;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.replication.common.ReplicationCollection;
import org.opensearch.indices.replication.common.ReplicationFailedException;
import org.opensearch.indices.replication.common.ReplicationListener;
import org.opensearch.indices.replication.common.ReplicationState;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.transport.TransportResponse;
import org.opensearch.transport.TransportService;

public class SegmentReplicationTargetService
implements IndexEventListener {
    private static final Logger logger = LogManager.getLogger(SegmentReplicationTargetService.class);
    private final ThreadPool threadPool;
    private final RecoverySettings recoverySettings;
    private final ReplicationCollection<SegmentReplicationTarget> onGoingReplications;
    private final Map<ShardId, SegmentReplicationTarget> completedReplications = ConcurrentCollections.newConcurrentMap();
    private final SegmentReplicationSourceFactory sourceFactory;
    private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap();
    private final IndicesService indicesService;
    public static final SegmentReplicationTargetService NO_OP = new SegmentReplicationTargetService(){

        @Override
        public void beforeIndexShardClosed(ShardId shardId, IndexShard indexShard, Settings indexSettings) {
        }

        @Override
        public synchronized void onNewCheckpoint(ReplicationCheckpoint receivedCheckpoint, IndexShard replicaShard) {
        }

        @Override
        public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) {
        }
    };

    private SegmentReplicationTargetService() {
        this.threadPool = null;
        this.recoverySettings = null;
        this.onGoingReplications = null;
        this.sourceFactory = null;
        this.indicesService = null;
    }

    public ReplicationCollection.ReplicationRef<SegmentReplicationTarget> get(long replicationId) {
        return this.onGoingReplications.get(replicationId);
    }

    public SegmentReplicationTargetService(ThreadPool threadPool, RecoverySettings recoverySettings, TransportService transportService, SegmentReplicationSourceFactory sourceFactory, IndicesService indicesService) {
        this.threadPool = threadPool;
        this.recoverySettings = recoverySettings;
        this.onGoingReplications = new ReplicationCollection(logger, threadPool);
        this.sourceFactory = sourceFactory;
        this.indicesService = indicesService;
        transportService.registerRequestHandler("internal:index/shard/replication/file_chunk", "generic", FileChunkRequest::new, new FileChunkTransportRequestHandler());
        transportService.registerRequestHandler("internal:index/shard/replication/segments_sync", "generic", ForceSyncRequest::new, new ForceSyncTransportRequestHandler());
    }

    @Override
    public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) {
        if (indexShard != null) {
            this.onGoingReplications.cancelForShard(shardId, "shard closed");
        }
    }

    @Override
    public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) {
        if (oldRouting != null && !oldRouting.primary() && newRouting.primary()) {
            this.onGoingReplications.cancelForShard(indexShard.shardId(), "shard has been promoted to primary");
        }
    }

    @Nullable
    public SegmentReplicationState getOngoingEventSegmentReplicationState(ShardId shardId) {
        return Optional.ofNullable(this.onGoingReplications.getOngoingReplicationTarget(shardId)).map(SegmentReplicationTarget::state).orElse(null);
    }

    @Nullable
    public SegmentReplicationState getlatestCompletedEventSegmentReplicationState(ShardId shardId) {
        return Optional.ofNullable(this.completedReplications.get(shardId)).map(SegmentReplicationTarget::state).orElse(null);
    }

    @Nullable
    public SegmentReplicationState getSegmentReplicationState(ShardId shardId) {
        return Optional.ofNullable(this.getOngoingEventSegmentReplicationState(shardId)).orElseGet(() -> this.getlatestCompletedEventSegmentReplicationState(shardId));
    }

    public synchronized void onNewCheckpoint(ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) {
        logger.trace(() -> new ParameterizedMessage("Replica received new replication checkpoint from primary [{}]", (Object)receivedCheckpoint));
        if (replicaShard.state().equals((Object)IndexShardState.STARTED)) {
            SegmentReplicationTarget ongoingReplicationTarget;
            if (this.latestReceivedCheckpoint.get(replicaShard.shardId()) != null) {
                if (receivedCheckpoint.isAheadOf(this.latestReceivedCheckpoint.get(replicaShard.shardId()))) {
                    this.latestReceivedCheckpoint.replace(replicaShard.shardId(), receivedCheckpoint);
                }
            } else {
                this.latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint);
            }
            if ((ongoingReplicationTarget = this.onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId())) != null) {
                if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) {
                    logger.trace("Cancelling ongoing replication from old primary with primary term {}", (Object)ongoingReplicationTarget.getCheckpoint().getPrimaryTerm());
                    this.onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary");
                    this.completedReplications.put(replicaShard.shardId(), ongoingReplicationTarget);
                } else {
                    logger.trace(() -> new ParameterizedMessage("Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}", (Object)replicaShard.getLatestReplicationCheckpoint()));
                    return;
                }
            }
            final Thread thread = Thread.currentThread();
            if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
                this.startReplication(receivedCheckpoint, replicaShard, new SegmentReplicationListener(){

                    @Override
                    public void onReplicationDone(SegmentReplicationState state) {
                        logger.trace(() -> new ParameterizedMessage("[shardId {}] [replication id {}] Replication complete, timing data: {}", new Object[]{replicaShard.shardId().getId(), state.getReplicationId(), state.getTimingData()}));
                        if (SegmentReplicationTargetService.this.latestReceivedCheckpoint.get(replicaShard.shardId()).isAheadOf(replicaShard.getLatestReplicationCheckpoint())) {
                            Runnable runnable = () -> SegmentReplicationTargetService.this.onNewCheckpoint(SegmentReplicationTargetService.this.latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard);
                            if (thread == Thread.currentThread()) {
                                SegmentReplicationTargetService.this.threadPool.generic().execute(runnable);
                            } else {
                                runnable.run();
                            }
                        }
                    }

                    @Override
                    public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) {
                        logger.trace(() -> new ParameterizedMessage("[shardId {}] [replication id {}] Replication failed, timing data: {}", new Object[]{replicaShard.shardId().getId(), state.getReplicationId(), state.getTimingData()}));
                        if (sendShardFailure) {
                            logger.error("replication failure", (Throwable)e);
                            replicaShard.failShard("replication failure", e);
                        }
                    }
                });
            }
        }
    }

    public SegmentReplicationTarget startReplication(ReplicationCheckpoint checkpoint, IndexShard indexShard, SegmentReplicationListener listener) {
        SegmentReplicationTarget target = new SegmentReplicationTarget(checkpoint, indexShard, this.sourceFactory.get(indexShard), (ReplicationListener)listener);
        this.startReplication(target);
        return target;
    }

    void startReplication(SegmentReplicationTarget target) {
        long replicationId = this.onGoingReplications.start(target, this.recoverySettings.activityTimeout());
        this.threadPool.generic().execute(new ReplicationRunner(replicationId));
    }

    private void start(final long replicationId) {
        try (ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = this.onGoingReplications.get(replicationId);){
            if (replicationRef == null) {
                return;
            }
            final SegmentReplicationTarget target = this.onGoingReplications.getTarget(replicationId);
            ((SegmentReplicationTarget)((Object)replicationRef.get())).startReplication(new ActionListener<Void>(){

                @Override
                public void onResponse(Void o) {
                    SegmentReplicationTargetService.this.onGoingReplications.markAsDone(replicationId);
                    if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0L) {
                        SegmentReplicationTargetService.this.completedReplications.put(target.shardId(), target);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    Throwable cause = ExceptionsHelper.unwrapCause(e);
                    if (cause instanceof CancellableThreads.ExecutionCancelledException) {
                        if (SegmentReplicationTargetService.this.onGoingReplications.getTarget(replicationId) != null) {
                            IndexShard indexShard = SegmentReplicationTargetService.this.onGoingReplications.getTarget(replicationId).indexShard();
                            SegmentReplicationTargetService.this.onGoingReplications.fail(replicationId, new ReplicationFailedException(indexShard, cause), false);
                            SegmentReplicationTargetService.this.completedReplications.put(target.shardId(), target);
                        }
                    } else {
                        SegmentReplicationTargetService.this.onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", (Throwable)e), true);
                    }
                }
            });
        }
    }

    public static class Actions {
        public static final String FILE_CHUNK = "internal:index/shard/replication/file_chunk";
        public static final String FORCE_SYNC = "internal:index/shard/replication/segments_sync";
    }

    private class FileChunkTransportRequestHandler
    implements TransportRequestHandler<FileChunkRequest> {
        final AtomicLong bytesSinceLastPause = new AtomicLong();

        private FileChunkTransportRequestHandler() {
        }

        @Override
        public void messageReceived(FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
            try (ReplicationCollection.ReplicationRef<SegmentReplicationTarget> ref = SegmentReplicationTargetService.this.onGoingReplications.getSafe(request.recoveryId(), request.shardId());){
                SegmentReplicationTarget target = (SegmentReplicationTarget)((Object)ref.get());
                ActionListener<Void> listener = target.createOrFinishListener(channel, "internal:index/shard/replication/file_chunk", request);
                target.handleFileChunk(request, target, this.bytesSinceLastPause, SegmentReplicationTargetService.this.recoverySettings.rateLimiter(), listener);
            }
        }
    }

    private class ForceSyncTransportRequestHandler
    implements TransportRequestHandler<ForceSyncRequest> {
        private ForceSyncTransportRequestHandler() {
        }

        @Override
        public void messageReceived(ForceSyncRequest request, final TransportChannel channel, Task task) throws Exception {
            assert (SegmentReplicationTargetService.this.indicesService != null);
            final IndexShard indexShard = (IndexShard)SegmentReplicationTargetService.this.indicesService.getShardOrNull(request.getShardId());
            SegmentReplicationTargetService.this.startReplication(ReplicationCheckpoint.empty(request.getShardId()), indexShard, new SegmentReplicationListener(){

                @Override
                public void onReplicationDone(SegmentReplicationState state) {
                    logger.trace(() -> new ParameterizedMessage("[shardId {}] [replication id {}] Replication complete, timing data: {}", new Object[]{indexShard.shardId().getId(), state.getReplicationId(), state.getTimingData()}));
                    try {
                        if (indexShard.recoveryState().getPrimary()) {
                            indexShard.resetToWriteableEngine();
                        }
                        channel.sendResponse(TransportResponse.Empty.INSTANCE);
                    }
                    catch (IOException | InterruptedException | TimeoutException e) {
                        throw new RuntimeException(e);
                    }
                }

                @Override
                public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) {
                    logger.trace(() -> new ParameterizedMessage("[shardId {}] [replication id {}] Replication failed, timing data: {}", new Object[]{indexShard.shardId().getId(), state.getReplicationId(), state.getTimingData()}));
                    if (sendShardFailure) {
                        indexShard.failShard("replication failure", e);
                    }
                    try {
                        channel.sendResponse(e);
                    }
                    catch (IOException ex) {
                        throw new RuntimeException(ex);
                    }
                }
            });
        }
    }

    public static interface SegmentReplicationListener
    extends ReplicationListener {
        @Override
        default public void onDone(ReplicationState state) {
            this.onReplicationDone((SegmentReplicationState)state);
        }

        @Override
        default public void onFailure(ReplicationState state, ReplicationFailedException e, boolean sendShardFailure) {
            this.onReplicationFailure((SegmentReplicationState)state, e, sendShardFailure);
        }

        public void onReplicationDone(SegmentReplicationState var1);

        public void onReplicationFailure(SegmentReplicationState var1, ReplicationFailedException var2, boolean var3);
    }

    private class ReplicationRunner
    implements Runnable {
        final long replicationId;

        public ReplicationRunner(long replicationId) {
            this.replicationId = replicationId;
        }

        @Override
        public void run() {
            SegmentReplicationTargetService.this.start(this.replicationId);
        }
    }
}

