/*
 * Copyright 2023 LiveKit, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package rtc

import (
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"

	"github.com/livekit/livekit-server/pkg/rtc/types"
	"github.com/livekit/livekit-server/pkg/rtc/types/typesfakes"
	"github.com/livekit/livekit-server/pkg/utils"
	"github.com/livekit/protocol/livekit"
	"github.com/livekit/protocol/logger"
)

func init() {
	reconcileInterval = 50 * time.Millisecond
	notFoundTimeout = 200 * time.Millisecond
	subscriptionTimeout = 200 * time.Millisecond
}

// Polling parameters passed to require.Eventually throughout these tests.
const (
	// subCheckInterval is how often an Eventually condition is re-evaluated.
	subCheckInterval = 10 * time.Millisecond
	// subSettleTimeout bounds how long an Eventually condition may take to hold.
	subSettleTimeout = 600 * time.Millisecond
)

// TestSubscribe exercises the subscribe path of SubscriptionManager: a normal
// subscription (including recovery after the subscribed track is closed), a
// subscription that lacks permission until the resolver grants it, and a
// subscription whose publisher track goes away.
func TestSubscribe(t *testing.T) {
	t.Run("happy path subscribe", func(t *testing.T) {
		sm := newTestSubscriptionManager()
		defer sm.Close(false)
		resolver := newTestResolver(true, true, "pub", "pubID")
		sm.params.TrackResolver = resolver.Resolve
		subCount := atomic.Int32{}
		failed := atomic.Bool{}
		sm.params.OnTrackSubscribed = func(subTrack types.SubscribedTrack) {
			subCount.Add(1)
		}
		sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) {
			failed.Store(true)
		}
		// count publisher-level subscribe / unsubscribe status transitions
		numParticipantSubscribed := atomic.Int32{}
		numParticipantUnsubscribed := atomic.Int32{}
		sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) {
			if subscribed {
				numParticipantSubscribed.Add(1)
			} else {
				numParticipantUnsubscribed.Add(1)
			}
		})

		sm.SubscribeToTrack("track", false)
		s := sm.subscriptions["track"]
		require.True(t, s.isDesired())
		require.Eventually(t, func() bool {
			return subCount.Load() == 1
		}, subSettleTimeout, subCheckInterval, "track was not subscribed")

		require.NotNil(t, s.getSubscribedTrack())
		require.Len(t, sm.GetSubscribedTracks(), 1)

		require.Eventually(t, func() bool {
			return len(sm.GetSubscribedParticipants()) == 1
		}, subSettleTimeout, subCheckInterval, "GetSubscribedParticipants should have returned one item")
		require.Equal(t, "pubID", string(sm.GetSubscribedParticipants()[0]))

		// ensure telemetry events are sent
		tl := sm.params.TelemetryListener.(*typesfakes.FakeParticipantTelemetryListener)
		require.Equal(t, 1, tl.OnTrackSubscribeRequestedCallCount())

		// ensure bound: fire the fake track's AddOnBind callbacks, then wait for
		// the subscription to stop reporting needsBind
		setTestSubscribedTrackBound(t, s.getSubscribedTrack())
		require.Eventually(t, func() bool {
			return !s.needsBind()
		}, subSettleTimeout, subCheckInterval, "track was not bound")

		// telemetry event should have been sent
		require.Equal(t, 1, tl.OnTrackSubscribedCallCount())

		// wait past notFoundTimeout and confirm no subscription error was reported
		time.Sleep(notFoundTimeout)
		require.False(t, failed.Load())

		// while paused the resolver returns no track, so the subscription cannot
		// be re-established right away after the close below
		resolver.SetPause(true)
		// ensure the subscription survives its subscribed track being closed
		setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false)
		require.Eventually(t, func() bool {
			return s.needsSubscribe()
		}, subSettleTimeout, subCheckInterval, "needs subscribe did not persist across track close")
		// un-pausing lets the resolver hand out a track again
		resolver.SetPause(false)

		require.Eventually(t, func() bool {
			return s.isDesired() && !s.needsSubscribe()
		}, subSettleTimeout, subCheckInterval, "track was not resubscribed")

		// was subscribed twice, unsubscribed once (due to close)
		require.Eventually(t, func() bool {
			return numParticipantSubscribed.Load() == 2
		}, subSettleTimeout, subCheckInterval, "participant subscribe status was not updated twice")
		require.Equal(t, int32(1), numParticipantUnsubscribed.Load())
	})

	t.Run("no track permission", func(t *testing.T) {
		sm := newTestSubscriptionManager()
		defer sm.Close(false)
		// resolver starts without permission but with a track available
		resolver := newTestResolver(false, true, "pub", "pubID")
		sm.params.TrackResolver = resolver.Resolve
		failed := atomic.Bool{}
		sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) {
			failed.Store(true)
		}

		sm.SubscribeToTrack("track", false)
		s := sm.subscriptions["track"]
		require.Eventually(t, func() bool {
			return !s.getHasPermission()
		}, subSettleTimeout, subCheckInterval, "should not have permission to subscribe")

		// wait past subscriptionTimeout before checking that nothing failed
		time.Sleep(subscriptionTimeout)

		// should not have called failed callbacks, isDesired remains unchanged
		require.True(t, s.isDesired())
		require.False(t, failed.Load())
		require.True(t, s.needsSubscribe())
		require.Len(t, sm.GetSubscribedTracks(), 0)

		// trackSubscribed telemetry not sent
		tl := sm.params.TelemetryListener.(*typesfakes.FakeParticipantTelemetryListener)
		require.Equal(t, 1, tl.OnTrackSubscribeRequestedCallCount())
		require.Equal(t, 0, tl.OnTrackSubscribedCallCount())

		// give permissions now
		resolver.lock.Lock()
		resolver.hasPermission = true
		resolver.lock.Unlock()

		require.Eventually(t, func() bool {
			return !s.needsSubscribe()
		}, subSettleTimeout, subCheckInterval, "should be subscribed")

		require.Len(t, sm.GetSubscribedTracks(), 1)
	})

	t.Run("publisher left", func(t *testing.T) {
		sm := newTestSubscriptionManager()
		defer sm.Close(false)
		resolver := newTestResolver(true, true, "pub", "pubID")
		sm.params.TrackResolver = resolver.Resolve
		failed := atomic.Bool{}
		sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) {
			failed.Store(true)
		}

		sm.SubscribeToTrack("track", false)
		s := sm.subscriptions["track"]
		require.Eventually(t, func() bool {
			return !s.needsSubscribe()
		}, subSettleTimeout, subCheckInterval, "should be subscribed")

		// from now on the resolver reports no track for this ID
		resolver.lock.Lock()
		resolver.hasTrack = false
		resolver.lock.Unlock()

		// publisher triggers close
		setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false)

		// with the track gone, the subscription is expected to stop being desired
		require.Eventually(t, func() bool {
			return !s.isDesired()
		}, subSettleTimeout, subCheckInterval, "isDesired not set to false")
	})
}

// TestUnsubscribe installs a hand-built, already subscribed and bound
// mediaTrackSubscription in the manager, unsubscribes it, and verifies that
// the subscription entry is removed, no subscribed tracks or change-notifier
// observers remain, and exactly one unsubscribe telemetry event is emitted.
func TestUnsubscribe(t *testing.T) {
	sm := newTestSubscriptionManager()
	defer sm.Close(false)
	unsubCount := atomic.Int32{}
	sm.params.OnTrackUnsubscribed = func(subTrack types.SubscribedTrack) {
		unsubCount.Add(1)
	}

	resolver := newTestResolver(true, true, "pub", "pubID")

	// subscription already in the desired + permitted + bound state
	s := &mediaTrackSubscription{
		trackSubscription: trackSubscription{
			trackID:           "track",
			desired:           true,
			subscriberID:      sm.params.Participant.ID(),
			publisherID:       "pubID",
			publisherIdentity: "pub",
			logger:            logger.GetLogger(),
		},
		hasPermission: true,
		bound:         true,
	}
	// a bunch of unfortunate manual wiring
	// NOTE(review): this presumably mirrors what the manager sets up during a
	// real subscribe (observer, subscribed track, close handler) — confirm.
	res := resolver.Resolve(nil, s.trackID)
	res.TrackChangedNotifier.AddObserver(string(sm.params.Participant.ID()), func() {})
	s.changedNotifier = res.TrackChangedNotifier
	st, err := res.Track.AddSubscriber(sm.params.Participant)
	require.NoError(t, err)
	s.subscribedTrack = st
	st.OnClose(func(isExpectedToResume bool) {
		sm.handleSubscribedTrackClose(s, isExpectedToResume)
	})
	// when the manager removes the subscriber from the media track, simulate
	// the subscribed track closing so the close handler above runs
	res.Track.(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) {
		setTestSubscribedTrackClosed(t, st, isExpectedToResume)
	})

	sm.lock.Lock()
	sm.subscriptions["track"] = s
	sm.lock.Unlock()

	require.False(t, s.needsSubscribe())
	require.False(t, s.needsUnsubscribe())

	// unsubscribe
	sm.UnsubscribeFromTrack("track")
	require.False(t, s.isDesired())

	// wait until the unsubscribe has fully completed: nothing left to
	// unsubscribe, no pending unsubscribes, and the entry removed from the map
	require.Eventually(t, func() bool {
		if s.needsUnsubscribe() {
			return false
		}
		if sm.pendingUnsubscribes.Load() != 0 {
			return false
		}
		sm.lock.RLock()
		subLen := len(sm.subscriptions)
		sm.lock.RUnlock()
		if subLen != 0 {
			return false
		}
		return true
	}, subSettleTimeout, subCheckInterval, "Track was not unsubscribed")

	// no traces should be left
	require.Len(t, sm.GetSubscribedTracks(), 0)
	require.False(t, res.TrackChangedNotifier.HasObservers())

	tl := sm.params.TelemetryListener.(*typesfakes.FakeParticipantTelemetryListener)
	require.Equal(t, 1, tl.OnTrackUnsubscribedCallCount())
}

// TestSubscribeStatusChanged verifies that the publisher-level subscribe
// status callback fires once when the first of a publisher's tracks is
// subscribed, does not fire when only one of two tracks is unsubscribed, and
// fires an "unsubscribed" event once the last track is unsubscribed.
func TestSubscribeStatusChanged(t *testing.T) {
	sm := newTestSubscriptionManager()
	defer sm.Close(false)
	resolver := newTestResolver(true, true, "pub", "pubID")
	sm.params.TrackResolver = resolver.Resolve
	numParticipantSubscribed := atomic.Int32{}
	numParticipantUnsubscribed := atomic.Int32{}
	sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) {
		if subscribed {
			numParticipantSubscribed.Add(1)
		} else {
			numParticipantUnsubscribed.Add(1)
		}
	})

	sm.SubscribeToTrack("track1", false)
	sm.SubscribeToTrack("track2", false)
	// the manager's background work also touches the map, so read it under lock
	sm.lock.RLock()
	s1 := sm.subscriptions["track1"]
	s2 := sm.subscriptions["track2"]
	sm.lock.RUnlock()
	require.Eventually(t, func() bool {
		return !s1.needsSubscribe() && !s2.needsSubscribe()
	}, subSettleTimeout, subCheckInterval, "track1 and track2 should be subscribed")
	st1 := s1.getSubscribedTrack()
	st1.OnClose(func(isExpectedToResume bool) {
		sm.handleSubscribedTrackClose(s1, isExpectedToResume)
	})
	st2 := s2.getSubscribedTrack()
	st2.OnClose(func(isExpectedToResume bool) {
		sm.handleSubscribedTrackClose(s2, isExpectedToResume)
	})
	// removing the subscriber from a media track simulates that track closing
	st1.MediaTrack().(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) {
		setTestSubscribedTrackClosed(t, st1, isExpectedToResume)
	})
	st2.MediaTrack().(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) {
		setTestSubscribedTrackClosed(t, st2, isExpectedToResume)
	})

	require.Eventually(t, func() bool {
		return numParticipantSubscribed.Load() == 1
	}, subSettleTimeout, subCheckInterval, "should be subscribed to publisher")
	require.Equal(t, int32(0), numParticipantUnsubscribed.Load())
	require.True(t, sm.IsSubscribedTo("pubID"))

	// now unsubscribe track2, no event should be fired
	sm.UnsubscribeFromTrack("track2")
	require.Eventually(t, func() bool {
		return !s2.needsUnsubscribe()
	}, subSettleTimeout, subCheckInterval, "track2 should be unsubscribed")
	require.Equal(t, int32(0), numParticipantUnsubscribed.Load())

	// unsubscribe track1, expect event
	sm.UnsubscribeFromTrack("track1")
	require.Eventually(t, func() bool {
		return !s1.needsUnsubscribe()
	}, subSettleTimeout, subCheckInterval, "track1 should be unsubscribed")
	require.Eventually(t, func() bool {
		return numParticipantUnsubscribed.Load() == 1
	}, subSettleTimeout, subCheckInterval, "should be unsubscribed from publisher")
	require.False(t, sm.IsSubscribedTo("pubID"))
}

// TestUpdateSettingsBeforeSubscription covers clients that send updated track
// settings before the corresponding subscription has been established. The
// settings must be kept and applied to the subscribed track once the
// subscription does take place.
func TestUpdateSettingsBeforeSubscription(t *testing.T) {
	sm := newTestSubscriptionManager()
	defer sm.Close(false)

	resolver := newTestResolver(true, true, "pub", "pubID")
	sm.params.TrackResolver = resolver.Resolve

	want := &livekit.UpdateTrackSettings{
		Disabled: true,
		Width:    100,
		Height:   100,
	}
	// settings arrive first ...
	sm.UpdateSubscribedTrackSettings("track", want)
	// ... and the subscription request only afterwards
	sm.SubscribeToTrack("track", false)

	sub := sm.subscriptions["track"]
	require.Eventually(t, func() bool {
		return !sub.needsSubscribe()
	}, subSettleTimeout, subCheckInterval, "Track should be subscribed")

	fakeTrack := sub.getSubscribedTrack().(*typesfakes.FakeSubscribedTrack)
	require.Eventually(t, func() bool {
		return fakeTrack.UpdateSubscriberSettingsCallCount() == 1
	}, subSettleTimeout, subCheckInterval, "UpdateSubscriberSettings should be called once")

	got, _ := fakeTrack.UpdateSubscriberSettingsArgsForCall(0)
	require.Equal(t, want.Disabled, got.Disabled)
	require.Equal(t, want.Width, got.Width)
	require.Equal(t, want.Height, got.Height)
}

// TestSubscriptionLimits verifies that with audio and video subscription limits
// of 1, a second subscription stays pending (recording a subscribe-failed
// telemetry event) while the first track is subscribed, is fulfilled once the
// first track is released, and that re-subscribing to the first track is then
// itself held pending.
func TestSubscriptionLimits(t *testing.T) {
	sm := newTestSubscriptionManagerWithParams(testSubscriptionParams{
		SubscriptionLimitAudio: 1,
		SubscriptionLimitVideo: 1,
	})
	defer sm.Close(false)
	resolver := newTestResolver(true, true, "pub", "pubID")
	sm.params.TrackResolver = resolver.Resolve
	subCount := atomic.Int32{}
	failed := atomic.Bool{}
	sm.params.OnTrackSubscribed = func(subTrack types.SubscribedTrack) {
		subCount.Add(1)
	}
	sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) {
		failed.Store(true)
	}
	numParticipantSubscribed := atomic.Int32{}
	numParticipantUnsubscribed := atomic.Int32{}
	sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) {
		if subscribed {
			numParticipantSubscribed.Add(1)
		} else {
			numParticipantUnsubscribed.Add(1)
		}
	})

	sm.SubscribeToTrack("track", false)
	// the manager's background work also mutates the map, so read it under lock
	sm.lock.RLock()
	s := sm.subscriptions["track"]
	sm.lock.RUnlock()
	require.True(t, s.isDesired())
	require.Eventually(t, func() bool {
		return subCount.Load() == 1
	}, subSettleTimeout, subCheckInterval, "track was not subscribed")

	require.NotNil(t, s.getSubscribedTrack())
	require.Len(t, sm.GetSubscribedTracks(), 1)

	require.Eventually(t, func() bool {
		return len(sm.GetSubscribedParticipants()) == 1
	}, subSettleTimeout, subCheckInterval, "GetSubscribedParticipants should have returned one item")
	require.Equal(t, "pubID", string(sm.GetSubscribedParticipants()[0]))

	// ensure telemetry events are sent
	tl := sm.params.TelemetryListener.(*typesfakes.FakeParticipantTelemetryListener)
	require.Equal(t, 1, tl.OnTrackSubscribeRequestedCallCount())

	// ensure bound
	setTestSubscribedTrackBound(t, s.getSubscribedTrack())
	require.Eventually(t, func() bool {
		return !s.needsBind()
	}, subSettleTimeout, subCheckInterval, "track was not bound")

	// telemetry event should have been sent
	require.Equal(t, 1, tl.OnTrackSubscribedCallCount())

	// reach subscription limit, subscribe pending
	sm.SubscribeToTrack("track2", false)
	sm.lock.RLock()
	s2 := sm.subscriptions["track2"]
	sm.lock.RUnlock()
	// deliberately wait: the assertions below check that track2 *stays* pending
	time.Sleep(subscriptionTimeout * 2)
	require.True(t, s2.needsSubscribe())
	require.Equal(t, 2, tl.OnTrackSubscribeRequestedCallCount())
	require.Equal(t, 1, tl.OnTrackSubscribeFailedCallCount())
	require.Len(t, sm.GetSubscribedTracks(), 1)

	// unsubscribe track1, then track2 should be subscribed
	sm.UnsubscribeFromTrack("track")
	require.False(t, s.isDesired())
	require.True(t, s.needsUnsubscribe())
	// wait for unsubscribe to take effect
	time.Sleep(reconcileInterval)
	setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false)
	require.Nil(t, s.getSubscribedTrack())

	// Poll instead of sleeping a single reconcileInterval: whether track2 is
	// picked up within exactly one interval depends on scheduling.
	require.Eventually(t, func() bool {
		return !s2.needsSubscribe() && subCount.Load() == 2
	}, subSettleTimeout, subCheckInterval, "track2 was not subscribed after track1 was released")
	require.True(t, s2.isDesired())
	require.NotNil(t, s2.getSubscribedTrack())
	require.Equal(t, 2, tl.OnTrackSubscribeRequestedCallCount())
	require.Len(t, sm.GetSubscribedTracks(), 1)

	// ensure bound
	setTestSubscribedTrackBound(t, s2.getSubscribedTrack())
	require.Eventually(t, func() bool {
		return !s2.needsBind()
	}, subSettleTimeout, subCheckInterval, "track was not bound")

	// subscribe to track1 again, which should remain pending
	sm.SubscribeToTrack("track", false)
	sm.lock.RLock()
	s = sm.subscriptions["track"]
	sm.lock.RUnlock()
	require.True(t, s.isDesired())
	// deliberately wait: the assertions below check that track1 *stays* pending
	time.Sleep(subscriptionTimeout * 2)
	require.True(t, s.needsSubscribe())
	require.Equal(t, 3, tl.OnTrackSubscribeRequestedCallCount())
	require.Equal(t, 2, tl.OnTrackSubscribeFailedCallCount())
	require.Len(t, sm.GetSubscribedTracks(), 1)
}

// testSubscriptionParams holds the optional settings used when building a
// SubscriptionManager for tests. Its zero value leaves both limits at 0.
type testSubscriptionParams struct {
	// SubscriptionLimitAudio is copied into SubscriptionManagerParams.SubscriptionLimitAudio.
	SubscriptionLimitAudio int32
	// SubscriptionLimitVideo is copied into SubscriptionManagerParams.SubscriptionLimitVideo.
	SubscriptionLimitVideo int32
}

// newTestSubscriptionManager returns a test SubscriptionManager built from a
// zero-valued testSubscriptionParams.
func newTestSubscriptionManager() *SubscriptionManager {
	var defaults testSubscriptionParams
	return newTestSubscriptionManagerWithParams(defaults)
}

// newTestSubscriptionManagerWithParams builds a SubscriptionManager around a
// fake local participant (identity "sub", ID "subID", standard kind, allowed to
// subscribe). It uses no-op callbacks, a resolver that always returns an empty
// result, and a fake telemetry listener. The audio/video subscription limits
// come from params. Tests typically overwrite sm.params fields afterwards.
func newTestSubscriptionManagerWithParams(params testSubscriptionParams) *SubscriptionManager {
	participant := &typesfakes.FakeLocalParticipant{}
	participant.IDReturns("subID")
	participant.IdentityReturns("sub")
	participant.KindReturns(livekit.ParticipantInfo_STANDARD)
	participant.CanSubscribeReturns(true)

	emptyResolver := func(sub types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult {
		return types.MediaResolverResult{}
	}

	smParams := SubscriptionManagerParams{
		Participant:            participant,
		Logger:                 logger.GetLogger(),
		TelemetryListener:      &typesfakes.FakeParticipantTelemetryListener{},
		TrackResolver:          emptyResolver,
		OnTrackSubscribed:      func(types.SubscribedTrack) {},
		OnTrackUnsubscribed:    func(types.SubscribedTrack) {},
		OnSubscriptionError:    func(livekit.TrackID, bool, error) {},
		SubscriptionLimitAudio: params.SubscriptionLimitAudio,
		SubscriptionLimitVideo: params.SubscriptionLimitVideo,
	}
	return NewSubscriptionManager(smParams)
}

// testResolver is a configurable stand-in for the manager's track resolver.
// All fields after lock are guarded by it, because tests change them while
// the manager may be calling Resolve.
type testResolver struct {
	lock sync.Mutex

	// hasPermission is reported as MediaResolverResult.HasPermission.
	hasPermission bool
	// hasTrack controls whether Resolve returns a fake media track.
	hasTrack bool
	// paused withholds the track even when hasTrack is true.
	paused bool

	// publisher reported in every result
	pubIdentity livekit.ParticipantIdentity
	pubID       livekit.ParticipantID
}

// newTestResolver returns an unpaused testResolver for the given publisher,
// with the supplied permission and track-availability flags.
func newTestResolver(hasPermission bool, hasTrack bool, pubIdentity livekit.ParticipantIdentity, pubID livekit.ParticipantID) *testResolver {
	r := new(testResolver)
	r.hasPermission = hasPermission
	r.hasTrack = hasTrack
	r.pubIdentity = pubIdentity
	r.pubID = pubID
	return r
}

// SetPause sets whether Resolve withholds the media track. While paused,
// results carry no Track, as if the track were unavailable.
func (t *testResolver) SetPause(paused bool) {
	t.lock.Lock()
	t.paused = paused
	t.lock.Unlock()
}

// Resolve is the TrackResolver used by the tests. Each call returns fresh
// change and removal notifiers plus the configured permission and publisher.
// Unless the resolver is paused or has no track, it also returns a new
// FakeMediaTrack. That track's AddSubscriber hands back a FakeSubscribedTrack
// carrying trackID, the publisher ID/identity, the subscriber, and the media
// track itself. The subscriber argument is ignored.
func (t *testResolver) Resolve(_ types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult {
	t.lock.Lock()
	defer t.lock.Unlock()

	result := types.MediaResolverResult{
		TrackChangedNotifier: utils.NewChangeNotifier(),
		TrackRemovedNotifier: utils.NewChangeNotifier(),
		HasPermission:        t.hasPermission,
		PublisherID:          t.pubID,
		PublisherIdentity:    t.pubIdentity,
	}
	if !t.hasTrack || t.paused {
		return result
	}

	mediaTrack := &typesfakes.FakeMediaTrack{}
	subTrack := &typesfakes.FakeSubscribedTrack{}
	subTrack.IDReturns(trackID)
	subTrack.PublisherIDReturns(t.pubID)
	subTrack.PublisherIdentityReturns(t.pubIdentity)
	mediaTrack.AddSubscriberCalls(func(subscriber types.LocalParticipant) (types.SubscribedTrack, error) {
		// remember who subscribed so SubscribedTrack.Subscriber() reports it
		subTrack.SubscriberReturns(subscriber)
		return subTrack, nil
	})
	subTrack.MediaTrackReturns(mediaTrack)

	result.Track = mediaTrack
	return result
}

// setTestSubscribedTrackBound simulates the subscribed track becoming bound by
// invoking every callback registered through AddOnBind, in registration order,
// each with a nil argument. st must be a *typesfakes.FakeSubscribedTrack,
// otherwise the test fails.
func setTestSubscribedTrackBound(t *testing.T, st types.SubscribedTrack) {
	// report failures at the caller's line, not inside this helper
	t.Helper()

	fst, ok := st.(*typesfakes.FakeSubscribedTrack)
	require.True(t, ok)

	for i := 0; i < fst.AddOnBindCallCount(); i++ {
		fst.AddOnBindArgsForCall(i)(nil)
	}
}

// setTestSubscribedTrackClosed simulates the subscribed track closing by
// invoking the first callback registered through OnClose with
// isExpectedToResume. st must be a *typesfakes.FakeSubscribedTrack that has at
// least one OnClose registration. Otherwise the test fails with a clear
// message, rather than panicking on an out-of-range index inside the fake.
func setTestSubscribedTrackClosed(t *testing.T, st types.SubscribedTrack, isExpectedToResume bool) {
	// report failures at the caller's line, not inside this helper
	t.Helper()

	fst, ok := st.(*typesfakes.FakeSubscribedTrack)
	require.True(t, ok)
	require.NotZero(t, fst.OnCloseCallCount(), "no OnClose callback registered on subscribed track")

	// only the first registration is invoked; later OnClose registrations are not
	fst.OnCloseArgsForCall(0)(isExpectedToResume)
}
