// Outline post-process. For every pixel, compares view-space depth and decoded
// normals against the four axis-aligned neighbours. It then scales the scene
// colour by a per-layer strength factor where an edge is detected.
// NOTE(review): presumably drawn on a mesh covering the whole view (see vertex()) — confirm.
shader_type spatial;
render_mode unshaded;

// Scene colour that gets outlined; multiplied by the edge strength in fragment().
uniform sampler2D screen_texture : source_color, hint_screen_texture, filter_nearest;
// Raw depth buffer; converted to view-space depth by get_depth().
// NOTE(review): source_color has no obvious meaning on depth/normal screen textures
// (Godot's own examples omit it) — confirm it does not alter the sampled values.
uniform sampler2D depth_texture : source_color, hint_depth_texture, filter_nearest;
// Normal/roughness buffer; .xyz is decoded with `* 2.0 - 1.0` before use.
uniform sampler2D normal_texture : source_color, hint_normal_roughness_texture, filter_nearest;

// Layer toggles. depth_layer / normal_layer enable the two edge tests.
// When original_layer is false, pixels that are not on an enabled edge get
// strength 0.0, so they come out black.
uniform bool depth_layer = true;
uniform bool normal_layer = true;
uniform bool original_layer = true;

// The colour multiplier applied on an edge is 1.0 + strength.
// -1 gives black, 0 gives the unchanged colour, +1 gives double brightness.
uniform float depth_edge_strength : hint_range(-1, 1, 0.01) = -0.5;
uniform float normal_edge_strength : hint_range(-1, 1, 0.01) = 0.5;

// depth_edge_threshold is compared against the sum of (neighbour - centre)
// view-space depths. The units come from the projection matrix, presumably scene units.
// normal_edge_threshold is compared against the summed squared normal differences.
uniform float depth_edge_threshold : hint_range(0, 1) = 0.2;
uniform float normal_edge_threshold : hint_range(0, 1) = 0.6;

// Dotted with (centre normal - neighbour normal). A neighbour only contributes
// to the normal edge score when that dot product is about 0 or below.
uniform vec3 normal_edge_bias = vec3(1, 1, 1);

// Reconstructs the view-space depth of the pixel at `screen_uv`.
// The raw depth sample is un-projected with the inverse projection matrix,
// the perspective divide is applied, and view-space z is returned negated.
// With a camera looking down -Z, points in front of it therefore give
// positive values.
// NOTE(review): the raw depth is used directly as NDC z, which assumes [0, 1]
// depth as on Vulkan-based renderers. The Compatibility renderer would need
// depth * 2.0 - 1.0 — confirm the target renderer.
float get_depth(vec2 screen_uv, mat4 inv_projection_matrix) {
	float raw_depth = texture(depth_texture, screen_uv).r;
	vec4 ndc_point = vec4(screen_uv * 2.0 - 1.0, raw_depth, 1.0);
	vec4 view_point = inv_projection_matrix * ndc_point;
	return -(view_point.z / view_point.w);
}

// Returns the UVs of the four axis-aligned neighbours of `screen_uv`, one texel away.
// Order: [+y, -y, +x, -x].
// Each offset coordinate is clamped so the lookup stays inside the texture.
// NOTE(review): the upper bound is 1.0 - texel_size, but the lower bound is 0.0.
// With filter_nearest both should still land on the edge pixel.
vec2[4] create_uvs(vec2 screen_uv, vec2 texel_size) {
	vec2 next_y = vec2(screen_uv.x, min(screen_uv.y + texel_size.y, 1.0 - texel_size.y));
	vec2 prev_y = vec2(screen_uv.x, max(screen_uv.y - texel_size.y, 0.0));
	vec2 next_x = vec2(min(screen_uv.x + texel_size.x, 1.0 - texel_size.x), screen_uv.y);
	vec2 prev_x = vec2(max(screen_uv.x - texel_size.x, 0.0), screen_uv.y);
	vec2 neighbours[4] = vec2[4](next_y, prev_y, next_x, prev_x);
	return neighbours;
}

// Scores how much the normal stored at `uv` differs from the centre pixel's `normal`.
//   uv     - neighbour lookup coordinate into normal_texture.
//   normal - centre normal, already decoded to the [-1, 1] range.
// Returns the squared length of (normal - neighbour_normal). It is weighted by
// a soft indicator that is 1.0 when the difference projected onto
// normal_edge_bias is <= -0.01, 0.0 when it is >= 0.01, and smooth in between.
// The indicator makes only one side of a crease contribute.
float normal_detection(vec2 uv, vec3 normal) {
	// Decode the neighbour's normal from [0, 1] storage back to [-1, 1].
	vec3 neighbor_normal = texture(normal_texture, uv).xyz * 2.0 - 1.0;
	vec3 normal_diff = normal - neighbor_normal;
	float normal_bias_diff = dot(normal_diff, normal_edge_bias);
	// GLSL leaves smoothstep undefined when edge0 >= edge1. So the falling edge
	// is written as 1 - smoothstep(lo, hi, x). Because the Hermite curve is
	// symmetric, this is exactly the curve the reversed-edge call was after.
	float normal_indicator = 1.0 - smoothstep(-0.01, 0.01, normal_bias_diff);
	return dot(normal_diff, normal_diff) * normal_indicator;
}

// Writes the mesh's vertex positions directly as clip-space positions.
// The mesh is therefore drawn in screen space, independent of the camera.
// NOTE(review): assumes the mesh is a 2x2 quad spanning [-1, 1] so it covers
// the screen. Godot 4.3+ (reversed-Z) docs suggest vec4(VERTEX.xy, 1.0, 1.0)
// instead — confirm against the engine version in use.
void vertex() {
	vec3 clip_xyz = VERTEX.xyz;
	POSITION = vec4(clip_xyz, 1.0);
}

// Per-pixel edge classification and output.
// For each pixel, this:
//   1. Sums (neighbour depth - centre depth) over the four neighbours.
//      The result is positive when neighbours are farther away on balance.
//   2. Sums normal_detection() scores over the same neighbours.
//   3. Picks a colour multiplier. A depth edge wins over a normal edge.
//      A normal edge additionally requires the summed depth difference to be
//      positive. Otherwise the multiplier is 1.0 if original_layer is on,
//      else 0.0.
void fragment() {
	vec3 scene_color = texture(screen_texture, SCREEN_UV).rgb;
	vec3 center_normal = texture(normal_texture, SCREEN_UV).xyz * 2.0 - 1.0;
	float center_depth = get_depth(SCREEN_UV, INV_PROJECTION_MATRIX);
	vec2 pixel_step = 1.0 / VIEWPORT_SIZE.xy;
	vec2 samples[4] = create_uvs(SCREEN_UV, pixel_step);

	float depth_sum = 0.0;
	float normal_sum = 0.0;
	for (int k = 0; k < samples.length(); k++) {
		vec2 sample_uv = samples[k];
		depth_sum += get_depth(sample_uv, INV_PROJECTION_MATRIX) - center_depth;
		normal_sum += normal_detection(sample_uv, center_normal);
	}

	bool is_depth_edge = depth_layer && depth_sum > depth_edge_threshold;
	bool is_normal_edge = normal_layer && normal_sum > normal_edge_threshold && depth_sum > 0.0;

	// Fallback for pixels that are not on an enabled edge.
	float strength = original_layer ? 1.0 : 0.0;
	if (is_depth_edge) {
		strength = 1.0 + depth_edge_strength;
	} else if (is_normal_edge) {
		strength = 1.0 + normal_edge_strength;
	}

	ALBEDO = scene_color * strength;
}