/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.indices.replication;

import java.io.IOException;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.index.CorruptIndexException;
import org.opensearch.OpenSearchCorruptionException;
import org.opensearch.common.SetOnce;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.store.Store;
import org.opensearch.indices.replication.SegmentReplicationSource;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationState;
import org.opensearch.indices.replication.SegmentReplicationTarget;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.replication.common.ReplicationCollection;
import org.opensearch.indices.replication.common.ReplicationFailedException;
import org.opensearch.indices.replication.common.ReplicationListener;
import org.opensearch.threadpool.ThreadPool;

/**
 * Runs segment replication events for shards.
 *
 * <p>Each replication is represented by a {@link SegmentReplicationTarget}. It is first registered in a
 * {@link ReplicationCollection} of ongoing replications and then executed on the thread pool's generic
 * executor. When a replication finishes, its state is remembered per shard, but only if it actually
 * recovered files and bytes.
 */
public class SegmentReplicator {
    private static final Logger logger = LogManager.getLogger(SegmentReplicator.class);

    /** Replications that have been registered and not yet marked as done or failed. */
    private final ReplicationCollection<SegmentReplicationTarget> onGoingReplications;

    /**
     * State of the last completed replication per shard. An entry is only written when the replication
     * recovered a non-zero number of files and a non-zero number of bytes.
     */
    private final Map<ShardId, SegmentReplicationState> completedReplications = ConcurrentCollections.newConcurrentMap();

    private final ThreadPool threadPool;

    /** Factory used to obtain a replication source for a shard; assigned via {@link #setSourceFactory}. */
    private final SetOnce<SegmentReplicationSourceFactory> sourceFactory;

    public SegmentReplicator(ThreadPool threadPool) {
        this.onGoingReplications = new ReplicationCollection<>(logger, threadPool);
        this.threadPool = threadPool;
        this.sourceFactory = new SetOnce<>();
    }

    /**
     * Starts a replication of {@code shard} towards its latest replication checkpoint, using a source
     * obtained from the configured source factory.
     *
     * <p>This is a no-op if no source factory has been set. On completion a trace message is logged. On
     * failure the error is logged, and the shard is failed when the failure requests it.
     *
     * @param shard the shard to replicate into
     */
    public void startReplication(final IndexShard shard) {
        final SegmentReplicationSourceFactory factory = sourceFactory.get();
        if (factory == null) {
            // setSourceFactory has not been called yet, so there is nothing to replicate from.
            return;
        }
        startReplication(
            shard,
            shard.getLatestReplicationCheckpoint(),
            factory.get(shard),
            new SegmentReplicationTargetService.SegmentReplicationListener() {

                @Override
                public void onReplicationDone(SegmentReplicationState state) {
                    logger.trace("Completed replication for {}", shard.shardId());
                }

                @Override
                public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) {
                    logger.error(() -> new ParameterizedMessage("Failed segment replication for {}", shard.shardId()), e);
                    if (sendShardFailure) {
                        shard.failShard("unrecoverable replication failure", e);
                    }
                }
            }
        );
    }

    /**
     * Sets the factory used by {@link #startReplication(IndexShard)} to obtain replication sources.
     * Backed by a {@link SetOnce}.
     */
    void setSourceFactory(SegmentReplicationSourceFactory sourceFactory) {
        this.sourceFactory.set(sourceFactory);
    }

    /**
     * Creates a replication target and schedules it.
     *
     * <p>The activity timeout is taken from the shard's recovery settings.
     *
     * @param indexShard the shard to replicate into
     * @param checkpoint the checkpoint to replicate to
     * @param source     where segments are fetched from
     * @param listener   notified of the outcome of the replication
     * @return the newly created target
     */
    SegmentReplicationTarget startReplication(
        IndexShard indexShard,
        ReplicationCheckpoint checkpoint,
        SegmentReplicationSource source,
        SegmentReplicationTargetService.SegmentReplicationListener listener
    ) {
        final SegmentReplicationTarget target = new SegmentReplicationTarget(indexShard, checkpoint, source, listener);
        startReplication(target, indexShard.getRecoverySettings().activityTimeout());
        return target;
    }

    /**
     * Looks up the registered replication {@code replicationId} and starts it.
     *
     * <p>On success the replication is marked as done. On failure it is failed in the collection; a shard
     * failure is requested when the store or the exception indicates corruption.
     */
    private void start(final long replicationId) {
        final SegmentReplicationTarget target;
        try (ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = onGoingReplications.get(replicationId)) {
            // The replication may already have been removed from the collection before this runner executed.
            if (replicationRef == null) {
                return;
            }
            target = replicationRef.get();
        }
        target.startReplication(new ActionListener<Void>() {

            @Override
            public void onResponse(Void o) {
                logger.debug(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description()));
                onGoingReplications.markAsDone(replicationId);
                final SegmentReplicationState state = target.state();
                // Only remember replications that actually copied data.
                if (state.getIndex().recoveredFileCount() != 0 && state.getIndex().recoveredBytes() != 0L) {
                    completedReplications.put(target.shardId(), state);
                }
            }

            @Override
            public void onFailure(Exception e) {
                // Include the cause; previously it was dropped from this log line.
                logger.debug(() -> new ParameterizedMessage("Replication failed {}", target.description()), e);
                // isStoreCorrupt is evaluated first, as before, because it touches the store's ref count.
                if (isStoreCorrupt(target) || e instanceof CorruptIndexException || e instanceof OpenSearchCorruptionException) {
                    onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true);
                    return;
                }
                onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false);
            }
        });
    }

    /**
     * Registers {@code target} in the collection of ongoing replications with the given timeout.
     *
     * <p>If registration succeeds, the replication is run on the generic executor. If the collection
     * rejects it with a {@link ReplicationFailedException}, the target is failed without requesting a
     * shard failure.
     */
    void startReplication(SegmentReplicationTarget target, TimeValue timeout) {
        final long replicationId;
        try {
            replicationId = onGoingReplications.startSafe(target, timeout);
        } catch (ReplicationFailedException e) {
            target.fail(e, false);
            return;
        }
        logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description()));
        threadPool.generic().execute(new ReplicationRunner(replicationId));
    }

    /**
     * Checks whether the target's store is marked as corrupted.
     *
     * <p>Returns {@code false} in three cases:
     * <ul>
     *   <li>the target's ref count is not positive;</li>
     *   <li>the store reference cannot be acquired;</li>
     *   <li>the corruption check itself throws an {@link IOException}, which is logged.</li>
     * </ul>
     */
    private boolean isStoreCorrupt(SegmentReplicationTarget target) {
        // NOTE(review): a non-positive refCount presumably means the target is already closed.
        if (target.refCount() <= 0) {
            return false;
        }
        final Store store = target.store();
        if (!store.tryIncRef()) {
            return false;
        }
        try {
            return store.isMarkedCorrupted();
        } catch (IOException ex) {
            logger.warn("Unable to determine if store is corrupt", ex);
            return false;
        } finally {
            store.decRef();
        }
    }

    /** Returns the number of ongoing replications tracked by the collection. */
    int size() {
        return onGoingReplications.size();
    }

    /** Cancels ongoing replications for {@code shardId} with the given reason. */
    void cancel(ShardId shardId, String reason) {
        onGoingReplications.cancelForShard(shardId, reason);
    }

    /** Returns the ongoing replication target for {@code shardId}, as reported by the collection. */
    SegmentReplicationTarget get(ShardId shardId) {
        return onGoingReplications.getOngoingReplicationTarget(shardId);
    }

    /** Returns a reference to the ongoing replication {@code id}; may be {@code null} (see {@link #start}). */
    ReplicationCollection.ReplicationRef<SegmentReplicationTarget> get(long id) {
        return onGoingReplications.get(id);
    }

    /** Returns the last recorded completed replication state for {@code shardId}, or {@code null} if none. */
    SegmentReplicationState getCompleted(ShardId shardId) {
        return completedReplications.get(shardId);
    }

    /** Returns a reference to replication {@code id} via the collection's {@code getSafe} for {@code shardId}. */
    ReplicationCollection.ReplicationRef<SegmentReplicationTarget> get(long id, ShardId shardId) {
        return onGoingReplications.getSafe(id, shardId);
    }

    /**
     * Task that starts replication {@code replicationId}.
     *
     * <p>Any unexpected exception fails the replication without requesting a shard failure.
     */
    private class ReplicationRunner extends AbstractRunnable {
        final long replicationId;

        public ReplicationRunner(long replicationId) {
            this.replicationId = replicationId;
        }

        @Override
        public void onFailure(Exception e) {
            onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false);
        }

        @Override
        public void doRun() {
            start(replicationId);
        }
    }
}

