/******************************************************************************
 * Copyright (c) 2016, NVIDIA CORPORATION.  All rights reserved.
 * Modifications Copyright (c) 2019-2025, Advanced Micro Devices, Inc.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
#pragma once

#include <thrust/detail/config.h>

#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_HIP

#  include <thrust/system/hip/config.h>

#  include <thrust/detail/minmax.h>
#  include <thrust/detail/mpl/math.h>
#  include <thrust/detail/range/head_flags.h>
#  include <thrust/detail/temporary_array.h>
#  include <thrust/distance.h>
#  include <thrust/extrema.h>
#  include <thrust/functional.h>
#  include <thrust/merge.h>
#  include <thrust/pair.h>
#  include <thrust/system/hip/detail/get_value.h>
#  include <thrust/system/hip/detail/par_to_seq.h>
#  include <thrust/system/hip/detail/util.h>

#  include <cstdint>

// rocPRIM includes
#  include <rocprim/rocprim.hpp>

THRUST_NAMESPACE_BEGIN
namespace hip_rocprim
{
namespace __merge
{
// Adapts a key comparator so it can order zipped (key, value) tuples.
// Only element 0 (the key) of each tuple is passed to the wrapped predicate;
// the value element is carried along but never inspected.
template <class KeyType, class ValueType, class Predicate>
struct predicate_wrapper
{
  // Wrapped user comparator, applied to keys only.
  Predicate predicate;

  // Tuple type the comparison operator accepts (key first, value second).
  using pair_type = rocprim::tuple<KeyType, ValueType>;

  THRUST_HIP_FUNCTION
  predicate_wrapper(Predicate pred)
      : predicate(pred)
  {}

  // Returns predicate(key(a), key(b)).
  bool THRUST_HIP_DEVICE_FUNCTION operator()(pair_type const& a, pair_type const& b) const
  {
    auto const& key_a = rocprim::get<0>(a);
    auto const& key_b = rocprim::get<0>(b);
    return predicate(key_a, key_b);
  }
}; // struct predicate_wrapper

template <class Derived, class KeysIt1, class KeysIt2, class ResultIt, class CompareOp>
ResultIt THRUST_HIP_RUNTIME_FUNCTION merge(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ResultIt result,
  CompareOp compare_op)
{
  // rocPRIM-backed merge of two sorted key ranges into `result`.
  // Returns an iterator one past the last merged element written.
  // Throws (via hip_rocprim::throw_on_error) if either rocPRIM pass or the
  // optional synchronization reports an error.

  // Lengths of both inputs, converted to the size_t counts rocprim::merge takes.
  const size_t count1 = static_cast<size_t>(thrust::distance(keys1_first, keys1_last));
  const size_t count2 = static_cast<size_t>(thrust::distance(keys2_first, keys2_last));

  // Both inputs empty: nothing is written, so the output ends where it begins.
  if (count1 == 0 && count2 == 0)
  {
    return result;
  }

  size_t temp_storage_bytes    = 0;
  const hipStream_t hip_stream = hip_rocprim::stream(policy);
  const bool synchronous_debug = THRUST_HIP_DEBUG_SYNC_FLAG;

  // Pass 1: with a null storage pointer rocPRIM only reports the scratch
  // size it needs (written into temp_storage_bytes).
  hip_rocprim::throw_on_error(
    rocprim::merge(
      nullptr,
      temp_storage_bytes,
      keys1_first,
      keys2_first,
      result,
      count1,
      count2,
      compare_op,
      hip_stream,
      synchronous_debug),
    "merge failed on 1st step");

  // Pass 2: allocate the scratch buffer through the policy and run the merge.
  thrust::detail::temporary_array<std::uint8_t, Derived> scratch(policy, temp_storage_bytes);
  void* const scratch_ptr = static_cast<void*>(scratch.data().get());

  hip_rocprim::throw_on_error(
    rocprim::merge(
      scratch_ptr,
      temp_storage_bytes,
      keys1_first,
      keys2_first,
      result,
      count1,
      count2,
      compare_op,
      hip_stream,
      synchronous_debug),
    "merge failed on 2nd step");

  // NOTE(review): synchronize_optional presumably waits on the stream only
  // when the policy asks for it -- confirm in hip_rocprim utilities.
  hip_rocprim::throw_on_error(hip_rocprim::synchronize_optional(policy), "merge: failed to synchronize");

  // The merged output contains exactly count1 + count2 elements.
  return result + count1 + count2;
}

template <typename Derived,
          typename KeysIt1,
          typename KeysIt2,
          typename ItemsIt1,
          typename ItemsIt2,
          typename KeysOutputIt,
          typename ItemsOutputIt,
          typename CompareOp>
THRUST_HIP_RUNTIME_FUNCTION pair<KeysOutputIt, ItemsOutputIt> merge(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ItemsIt1 items1_first,
  ItemsIt2 items2_first,
  KeysOutputIt keys_result,
  ItemsOutputIt items_result,
  CompareOp compare_op)
{
  // rocPRIM-backed merge_by_key.
  // Keys and items are zipped into (key, item) tuples and merged with a
  // comparator that looks only at the key, so each item travels with its key.
  // Returns the pair of output end iterators (keys_result + n, items_result + n),
  // where n is the total input length.
  // Throws (via hip_rocprim::throw_on_error) if either rocPRIM pass or the
  // optional synchronization reports an error.
  using size_type = size_t;

  using KeyType   = typename iterator_traits<KeysIt1>::value_type;
  using ValueType = typename iterator_traits<ItemsIt1>::value_type;

  // Compare zipped tuples by their key element only.
  predicate_wrapper<KeyType, ValueType, CompareOp> wrapped_binary_pred(compare_op);

  const size_type input1_size = static_cast<size_type>(thrust::distance(keys1_first, keys1_last));
  const size_type input2_size = static_cast<size_type>(thrust::distance(keys2_first, keys2_last));

  // Both inputs empty: nothing is written; return the untouched output starts.
  if (input1_size == 0 && input2_size == 0)
  {
    return thrust::make_pair(keys_result, items_result);
  }

  // Zip keys with their items once and reuse the iterators for both rocPRIM passes.
  auto input1 = rocprim::make_zip_iterator(rocprim::make_tuple(keys1_first, items1_first));
  auto input2 = rocprim::make_zip_iterator(rocprim::make_tuple(keys2_first, items2_first));
  auto output = rocprim::make_zip_iterator(rocprim::make_tuple(keys_result, items_result));

  size_t storage_size = 0;
  hipStream_t stream  = hip_rocprim::stream(policy);
  bool debug_sync     = THRUST_HIP_DEBUG_SYNC_FLAG;

  // Pass 1: a null storage pointer makes rocPRIM report the required
  // temporary storage size in storage_size.
  hip_rocprim::throw_on_error(
    rocprim::merge(
      nullptr, storage_size, input1, input2, output, input1_size, input2_size, wrapped_binary_pred, stream, debug_sync),
    "merge_by_key failed on 1st step");

  // Pass 2: allocate temporary storage through the policy and run the merge.
  thrust::detail::temporary_array<std::uint8_t, Derived> tmp(policy, storage_size);
  void* ptr = static_cast<void*>(tmp.data().get());

  hip_rocprim::throw_on_error(
    rocprim::merge(
      ptr, storage_size, input1, input2, output, input1_size, input2_size, wrapped_binary_pred, stream, debug_sync),
    "merge_by_key failed on 2nd step");
  // Name the actual operation so failures here are distinguishable from plain merge.
  hip_rocprim::throw_on_error(hip_rocprim::synchronize_optional(policy), "merge_by_key: failed to synchronize");

  const size_type count = input1_size + input2_size;
  return thrust::make_pair(keys_result + count, items_result + count);
}

} // namespace __merge

//-------------------------
// Thrust API entry points
//-------------------------
// Thrust API entry point: merges the sorted ranges [keys1_first, keys1_last)
// and [keys2_first, keys2_last) into `result`, ordered by `compare_op`, and
// returns the end of the written output range.
// When __THRUST_HAS_HIPRT__ is true the rocPRIM-backed __merge::merge is used;
// otherwise the call falls back to thrust::merge with the sequential policy
// obtained from cvt_to_seq.
THRUST_EXEC_CHECK_DISABLE
template <class Derived, class KeysIt1, class KeysIt2, class ResultIt, class CompareOp>
ResultIt THRUST_HIP_FUNCTION merge(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ResultIt result,
  CompareOp compare_op)

{
  // struct workaround is required for HIP-clang
  // NOTE(review): the host-only `par` / device-only `seq` split inside a local
  // struct presumably keeps each path from being compiled for the wrong side;
  // confirm before restructuring.
  struct workaround
  {
    // Host path: rocPRIM device-wide merge.
    THRUST_HOST static ResultIt
    par(execution_policy<Derived>& policy,
        KeysIt1 keys1_first,
        KeysIt1 keys1_last,
        KeysIt2 keys2_first,
        KeysIt2 keys2_last,
        ResultIt result,
        CompareOp compare_op)
    {
      return __merge::merge(policy, keys1_first, keys1_last, keys2_first, keys2_last, result, compare_op);
    }
    // Device path: generic thrust::merge under the sequential policy.
    THRUST_DEVICE static ResultIt
    seq(execution_policy<Derived>& policy,
        KeysIt1 keys1_first,
        KeysIt1 keys1_last,
        KeysIt2 keys2_first,
        KeysIt2 keys2_last,
        ResultIt result,
        CompareOp compare_op)
    {
      return thrust::merge(
        cvt_to_seq(derived_cast(policy)), keys1_first, keys1_last, keys2_first, keys2_last, result, compare_op);
    }
  };
#  if __THRUST_HAS_HIPRT__
  return workaround::par(policy, keys1_first, keys1_last, keys2_first, keys2_last, result, compare_op);
#  else
  return workaround::seq(policy, keys1_first, keys1_last, keys2_first, keys2_last, result, compare_op);
#  endif
}

// Thrust API entry point: merges two sorted key ranges, ordered by
// `compare_op`, and moves each item from items1_first / items2_first along
// with its key. Returns the pair of end iterators of the written key and
// item output ranges.
// When __THRUST_HAS_HIPRT__ is true the rocPRIM-backed __merge::merge
// (key/value overload) is used; otherwise the call falls back to
// thrust::merge_by_key with the sequential policy obtained from cvt_to_seq.
THRUST_EXEC_CHECK_DISABLE
template <class Derived,
          class KeysIt1,
          class KeysIt2,
          class ItemsIt1,
          class ItemsIt2,
          class KeysOutputIt,
          class ItemsOutputIt,
          class CompareOp>
pair<KeysOutputIt, ItemsOutputIt> THRUST_HIP_FUNCTION merge_by_key(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ItemsIt1 items1_first,
  ItemsIt2 items2_first,
  KeysOutputIt keys_result,
  ItemsOutputIt items_result,
  CompareOp compare_op)
{
  // struct workaround is required for HIP-clang
  // NOTE(review): the host-only `par` / device-only `seq` split inside a local
  // struct presumably keeps each path from being compiled for the wrong side;
  // confirm before restructuring.
  struct workaround
  {
    // Host path: rocPRIM device-wide merge of zipped (key, item) tuples.
    THRUST_HOST static pair<KeysOutputIt, ItemsOutputIt> par(
      execution_policy<Derived>& policy,
      KeysIt1 keys1_first,
      KeysIt1 keys1_last,
      KeysIt2 keys2_first,
      KeysIt2 keys2_last,
      ItemsIt1 items1_first,
      ItemsIt2 items2_first,
      KeysOutputIt keys_result,
      ItemsOutputIt items_result,
      CompareOp compare_op)
    {
      return __merge::merge(
        policy,
        keys1_first,
        keys1_last,
        keys2_first,
        keys2_last,
        items1_first,
        items2_first,
        keys_result,
        items_result,
        compare_op);
    }
    // Device path: generic thrust::merge_by_key under the sequential policy.
    THRUST_DEVICE static pair<KeysOutputIt, ItemsOutputIt> seq(
      execution_policy<Derived>& policy,
      KeysIt1 keys1_first,
      KeysIt1 keys1_last,
      KeysIt2 keys2_first,
      KeysIt2 keys2_last,
      ItemsIt1 items1_first,
      ItemsIt2 items2_first,
      KeysOutputIt keys_result,
      ItemsOutputIt items_result,
      CompareOp compare_op)
    {
      return thrust::merge_by_key(
        cvt_to_seq(derived_cast(policy)),
        keys1_first,
        keys1_last,
        keys2_first,
        keys2_last,
        items1_first,
        items2_first,
        keys_result,
        items_result,
        compare_op);
    }
  };

#  if __THRUST_HAS_HIPRT__
  return workaround::par(
    policy,
    keys1_first,
    keys1_last,
    keys2_first,
    keys2_last,
    items1_first,
    items2_first,
    keys_result,
    items_result,
    compare_op);
#  else
  return workaround::seq(
    policy,
    keys1_first,
    keys1_last,
    keys2_first,
    keys2_last,
    items1_first,
    items2_first,
    keys_result,
    items_result,
    compare_op);
#  endif
}

// merge() overload without a comparator: orders elements with the `less`
// functor instantiated on the value type of the first key range, then
// forwards to the comparator-taking overload.
THRUST_EXEC_CHECK_DISABLE
template <class Derived, class KeysIt1, class KeysIt2, class ResultIt>
ResultIt THRUST_HIP_FUNCTION merge(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ResultIt result)
{
  using default_compare = less<typename thrust::iterator_value<KeysIt1>::type>;
  return hip_rocprim::merge(
    policy, keys1_first, keys1_last, keys2_first, keys2_last, result, default_compare{});
}

// merge_by_key() overload without a comparator: orders keys with
// thrust::less instantiated on the value type of the first key range, then
// forwards to the comparator-taking overload.
template <class Derived, class KeysIt1, class KeysIt2, class ItemsIt1, class ItemsIt2, class KeysOutputIt, class ItemsOutputIt>
pair<KeysOutputIt, ItemsOutputIt> THRUST_HIP_FUNCTION merge_by_key(
  execution_policy<Derived>& policy,
  KeysIt1 keys1_first,
  KeysIt1 keys1_last,
  KeysIt2 keys2_first,
  KeysIt2 keys2_last,
  ItemsIt1 items1_first,
  ItemsIt2 items2_first,
  KeysOutputIt keys_result,
  ItemsOutputIt items_result)
{
  using key_compare = thrust::less<typename thrust::iterator_value<KeysIt1>::type>;
  return hip_rocprim::merge_by_key(
    policy, keys1_first, keys1_last, keys2_first, keys2_last, items1_first, items2_first, keys_result, items_result, key_compare{});
}

} // namespace hip_rocprim

THRUST_NAMESPACE_END
#endif
