/* SPDX-FileCopyrightText: 2024 Blender Authors
 *
 * SPDX-License-Identifier: GPL-2.0-or-later */

#include "editors/sculpt_paint/brushes/brushes.hh"

#include "DNA_brush_types.h"
#include "DNA_mesh_types.h"
#include "DNA_object_types.h"
#include "DNA_scene_types.h"

#include "BKE_mesh.hh"
#include "BKE_paint.hh"
#include "BKE_paint_bvh.hh"
#include "BKE_subdiv_ccg.hh"

#include "BLI_enumerable_thread_specific.hh"
#include "BLI_math_geom.h"
#include "BLI_math_matrix.h"
#include "BLI_math_matrix.hh"
#include "BLI_math_rotation.h"
#include "BLI_math_rotation_legacy.hh"
#include "BLI_math_vector.hh"
#include "BLI_task.hh"

#include "editors/sculpt_paint/mesh_brush_common.hh"
#include "editors/sculpt_paint/sculpt_intern.hh"

#include "bmesh.hh"

#include <numeric>

namespace blender::ed::sculpt_paint::brushes {

inline namespace clay_thumb_cc {

/** Scratch buffers reused between nodes; one instance is handed to each per-node call. */
struct LocalData {
  /* Positions gathered for the elements of the node currently being processed. */
  Vector<float3> positions;
  /* Per-element brush factors, filled by the `calc_factors_common_*` helpers and then scaled by
   * the stroke strength. */
  Vector<float> factors;
  /* Scratch storage handed to the `calc_factors_common_*` helpers alongside #factors. */
  Vector<float> distances;
  /* Per-element translations toward the tilted plane, resized per node before being filled. */
  Vector<float3> translations;
};

static void calc_faces(const Depsgraph &depsgraph,
                       const Sculpt &sd,
                       const Brush &brush,
                       const float4 &plane_tilt,
                       const float strength,
                       const MeshAttributeData &attribute_data,
                       const Span<float3> vert_normals,
                       const bke::pbvh::MeshNode &node,
                       Object &object,
                       LocalData &tls,
                       const PositionDeformData &position_data)
{
  SculptSession &ss = *object.sculpt;

  const Span<int> verts = node.verts();
  const MutableSpan positions = gather_data_mesh(position_data.eval, verts, tls.positions);

  calc_factors_common_mesh(depsgraph,
                           brush,
                           object,
                           attribute_data,
                           positions,
                           vert_normals,
                           node,
                           tls.factors,
                           tls.distances);

  scale_factors(tls.factors, strength);

  tls.translations.resize(verts.size());
  const MutableSpan<float3> translations = tls.translations;
  calc_translations_to_plane(positions, plane_tilt, translations);

  scale_translations(translations, tls.factors);

  clip_and_lock_translations(sd, ss, positions, translations);
  position_data.deform(translations, verts);
}

/**
 * Move the grid elements of a single multires node toward the tilted brush plane.
 *
 * \param plane_tilt: Plane the translations are computed against.
 * \param strength: Multiplier applied to the per-element brush factors.
 * \param tls: Scratch buffers; their previous contents are overwritten.
 */
static void calc_grids(const Depsgraph &depsgraph,
                       const Sculpt &sd,
                       const Brush &brush,
                       const float4 &plane_tilt,
                       const float strength,
                       const bke::pbvh::GridsNode &node,
                       Object &object,
                       LocalData &tls)
{
  SculptSession &session = *object.sculpt;
  SubdivCCG &ccg = *session.subdiv_ccg;

  const Span<int> node_grids = node.grids();
  const MutableSpan<float3> grid_positions = gather_grids_positions(
      ccg, node_grids, tls.positions);

  /* Brush influence per grid element, weighted by the stroke strength. */
  calc_factors_common_grids(
      depsgraph, brush, object, grid_positions, node, tls.factors, tls.distances);
  scale_factors(tls.factors, strength);

  /* Translations relative to the plane, attenuated by the factors computed above. */
  tls.translations.resize(grid_positions.size());
  const MutableSpan<float3> offsets = tls.translations;
  calc_translations_to_plane(grid_positions, plane_tilt, offsets);
  scale_translations(offsets, tls.factors);

  clip_and_lock_translations(sd, session, grid_positions, offsets);
  apply_translations(offsets, node_grids, ccg);
}

static void calc_bmesh(const Depsgraph &depsgraph,
                       const Sculpt &sd,
                       const Brush &brush,
                       const float4 &plane_tilt,
                       const float strength,
                       Object &object,
                       bke::pbvh::BMeshNode &node,
                       LocalData &tls)
{
  SculptSession &ss = *object.sculpt;

  const Set<BMVert *, 0> &verts = BKE_pbvh_bmesh_node_unique_verts(&node);
  const MutableSpan positions = gather_bmesh_positions(verts, tls.positions);

  calc_factors_common_bmesh(depsgraph, brush, object, positions, node, tls.factors, tls.distances);

  scale_factors(tls.factors, strength);

  tls.translations.resize(verts.size());
  const MutableSpan<float3> translations = tls.translations;
  calc_translations_to_plane(positions, plane_tilt, translations);

  scale_translations(translations, tls.factors);

  clip_and_lock_translations(sd, ss, positions, translations);
  apply_translations(translations, verts);
}

}  // namespace clay_thumb_cc

void do_clay_thumb_brush(const Depsgraph &depsgraph,
                         const Sculpt &sd,
                         Object &object,
                         const IndexMask &node_mask)
{
  const SculptSession &ss = *object.sculpt;
  bke::pbvh::Tree &pbvh = *bke::object::pbvh_get(object);
  const Brush &brush = *BKE_paint_brush_for_read(&sd.paint);
  const float3 &location = ss.cache->location_symm;

  float3 area_position;
  float3 sculpt_plane_normal;
  calc_brush_plane(depsgraph, brush, object, node_mask, area_position, sculpt_plane_normal);

  float3 area_normal = area_position;
  /* Ignore brush settings and recalculate the area normal. */
  if (brush.sculpt_plane != SCULPT_DISP_DIR_AREA || (brush.flag & BRUSH_ORIGINAL_NORMAL)) {
    area_normal = calc_area_normal(depsgraph, brush, object, node_mask).value_or(float3(0));
  }

  /* Delay the first daub because grab delta is not setup. */
  if (SCULPT_stroke_is_first_brush_step_of_symmetry_pass(*ss.cache)) {
    ss.cache->clay_thumb_brush.front_angle = 0.0f;
    return;
  }

  /* Simulate the clay accumulation by increasing the plane angle as more samples are added to the
   * stroke. */
  if (SCULPT_stroke_is_main_symmetry_pass(*ss.cache)) {
    ss.cache->clay_thumb_brush.front_angle += 0.8f;
    ss.cache->clay_thumb_brush.front_angle = std::clamp(
        ss.cache->clay_thumb_brush.front_angle, 0.0f, 60.0f);
  }

  if (math::is_zero(ss.cache->grab_delta_symm)) {
    return;
  }

  /* Initialize brush local-space matrix. */
  float4x4 mat = float4x4::identity();
  mat.x_axis() = math::cross(area_normal, ss.cache->grab_delta_symm);
  mat.y_axis() = math::cross(area_normal, mat.x_axis());
  mat.z_axis() = area_normal;
  mat.location() = ss.cache->location_symm;
  /* NOTE: #math::normalize behaves differently for some reason. */
  normalize_m4(mat.ptr());

  /* Scale brush local space matrix. */
  float4x4 scale = math::from_scale<float4x4>(float3(ss.cache->radius));
  const float4x4 tmat = mat * scale;

  float3 normal_tilt;
  rotate_v3_v3v3fl(normal_tilt,
                   area_position,
                   tmat.x_axis(),
                   DEG2RADF(-ss.cache->clay_thumb_brush.front_angle));

  /* Tilted plane (front part of the brush). */
  float4 plane_tilt;
  plane_from_point_normal_v3(plane_tilt, location, normal_tilt);

  const float clay_strength = ss.cache->bstrength * clay_thumb_get_stabilized_pressure(*ss.cache);
  threading::EnumerableThreadSpecific<LocalData> all_tls;
  switch (pbvh.type()) {
    case bke::pbvh::Type::Mesh: {
      const Mesh &mesh = *static_cast<Mesh *>(object.data);
      const MeshAttributeData attribute_data(mesh);
      MutableSpan<bke::pbvh::MeshNode> nodes = pbvh.nodes<bke::pbvh::MeshNode>();
      const PositionDeformData position_data(depsgraph, object);
      const Span<float3> vert_normals = bke::pbvh::vert_normals_eval(depsgraph, object);
      node_mask.foreach_index(GrainSize(1), [&](const int i) {
        LocalData &tls = all_tls.local();
        calc_faces(depsgraph,
                   sd,
                   brush,
                   plane_tilt,
                   clay_strength,
                   attribute_data,
                   vert_normals,
                   nodes[i],
                   object,
                   tls,
                   position_data);
        bke::pbvh::update_node_bounds_mesh(position_data.eval, nodes[i]);
      });
      break;
    }
    case bke::pbvh::Type::Grids: {
      SubdivCCG &subdiv_ccg = *object.sculpt->subdiv_ccg;
      MutableSpan<float3> positions = subdiv_ccg.positions;
      MutableSpan<bke::pbvh::GridsNode> nodes = pbvh.nodes<bke::pbvh::GridsNode>();
      node_mask.foreach_index(GrainSize(1), [&](const int i) {
        LocalData &tls = all_tls.local();
        calc_grids(depsgraph, sd, brush, plane_tilt, clay_strength, nodes[i], object, tls);
        bke::pbvh::update_node_bounds_grids(subdiv_ccg.grid_area, positions, nodes[i]);
      });
      break;
    }
    case bke::pbvh::Type::BMesh: {
      MutableSpan<bke::pbvh::BMeshNode> nodes = pbvh.nodes<bke::pbvh::BMeshNode>();
      node_mask.foreach_index(GrainSize(1), [&](const int i) {
        LocalData &tls = all_tls.local();
        calc_bmesh(depsgraph, sd, brush, plane_tilt, clay_strength, object, nodes[i], tls);
        bke::pbvh::update_node_bounds_bmesh(nodes[i]);
      });
      break;
    }
  }
  pbvh.tag_positions_changed(node_mask);
  pbvh.flush_bounds_to_parents();
}

/**
 * Return the arithmetic mean of the samples stored in the clay thumb pressure stabilizer of
 * \a cache.
 */
float clay_thumb_get_stabilized_pressure(const StrokeCache &cache)
{
  const auto &samples = cache.clay_thumb_brush.pressure_stabilizer;
  const float total = std::accumulate(samples.begin(), samples.end(), 0.0f);
  return total / samples.size();
}

}  // namespace blender::ed::sculpt_paint::brushes
