
June 30, 2017, 3 min read

DRF - Get Size of List Response

I am starting to really dig the Django REST framework (DRF), because it is flexible and easy to extend with custom behaviour. Building an API usually involves routing endpoints to resources in order to create, list, retrieve, update and delete them. Sometimes it's useful to know how many objects the API will give us before actually querying that data, because we don't want to fetch 2 MB of users just to get that number. DRF makes it trivial to add a query parameter flag that returns just the number of items in the list we would otherwise get.

Let's say we have a User model being exposed via the API. We want /api/v1/users to give us the list of all users, but /api/v1/users?count to simply return the number of users instead.

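from rest_framework import viewsets
from rest_framework.response import Response


class Countable(object):
    """Mixin: return only the number of items when ?count is present."""
    def list(self, request, *args, **kwargs):
        if "count" in request.query_params:
            # Apply the view's filters so the count matches the list response.
            queryset = self.filter_queryset(self.get_queryset())
            return Response(queryset.count())
        return super(Countable, self).list(request, *args, **kwargs)


# "mixins" is the project's own module containing Countable,
# not rest_framework.mixins. The mixin must be listed first.
class UserViewSet(mixins.Countable, viewsets.ModelViewSet):
    queryset = User.objects.all()
    serializer_class = UserSerializer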

Done! This also supports combining count with other query parameters (e.g. for filtering), like /api/v1/users?first_name=John&count.

Addendum: It is also very simple to add other custom behaviour, like returning only the first or last item this way:

def list(self, request, *args, **kwargs):
    if "first" in request.query_params:
        queryset = self.filter_queryset(self.get_queryset())
        data = self.get_serializer_class()(queryset.first()).data
        return Response(data)

    if "last" in request.query_params:
        queryset = self.filter_queryset(self.get_queryset())
        data = self.get_serializer_class()(queryset.last()).data
        return Response(data)

    return super(Countable, self).list(request, *args, **kwargs)
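The flags used so far (count, first, last) carry no value, so the checks only test whether the key is present in the query string. As a rough, framework-free illustration of that idea, the sketch below parses a query string with the standard library's urllib.parse.parse_qs. The helper has_flag is hypothetical and not part of Django or DRF; it only assumes that the framework keeps valueless keys in a similar way.

```python
from urllib.parse import parse_qs


def has_flag(query_string, name):
    """Return True if `name` appears as a key, with or without a value."""
    # keep_blank_values=True keeps keys such as "count" that have no "=value"
    # part; without it, parse_qs drops them.
    params = parse_qs(query_string, keep_blank_values=True)
    return name in params


print(has_flag("first_name=John&count", "count"))
print(has_flag("first_name=John", "count"))
```

Because only the presence of the key matters, /api/v1/users?count and /api/v1/users?count=1 would be treated the same way by a check like this.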

Note: If your serializers rely on things passed through the context, like the request object (useful e.g. to check user permissions), make sure to pass those as well when manually creating serializers:

context = {"request": request}
if "last" in request.query_params:
    queryset = self.filter_queryset(self.get_queryset())
    data = self.get_serializer_class()(queryset.last(),
                                       context=context).data

Here is a version that handles both a QuerySet and other iterables, such as a list, that have no support for .first() etc.

from django.db.models import QuerySet
from rest_framework.response import Response


class Countable(object):

    def list(self, request, *args, **kwargs):
        if "count" in request.query_params:
            queryset = self.filter_queryset(self.get_queryset())
            return Response(self._count(queryset))

        context = {"request": request}

        if "first" in request.query_params:
            queryset = self.filter_queryset(self.get_queryset())
            data = self.get_serializer_class()(self._first(queryset),
                                               context=context).data
            return Response(data)

        if "last" in request.query_params:
            queryset = self.filter_queryset(self.get_queryset())
            data = self.get_serializer_class()(self._last(queryset),
                                               context=context).data
            return Response(data)

        return super(Countable, self).list(request, *args, **kwargs)

    def _first(self, iterable):
        """Support getting the first item from either a QuerySet or a list."""
        # isinstance also matches custom QuerySet subclasses.
        if isinstance(iterable, QuerySet):
            return iterable.first()
        return iterable[0]

    def _last(self, iterable):
        """Support getting the last item from either a QuerySet or a list."""
        if isinstance(iterable, QuerySet):
            return iterable.last()
        return iterable[-1]

    def _count(self, iterable):
        """Support getting the length of either a QuerySet or a list."""
        if isinstance(iterable, QuerySet):
            return iterable.count()
        return len(iterable)
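Two details of the mixin are easy to miss: it has to be listed before the base view class so that its list() runs first, and indexing a plain list with [0] or [-1] raises IndexError when the list is empty. The sketch below shows both points without any framework. BaseListView, get_items() and UserView are hypothetical stand-ins, not DRF classes, and returning None for an empty collection is just one possible choice.

```python
class BaseListView:
    """Hypothetical stand-in for a DRF viewset; not a real DRF class."""

    items = []

    def get_items(self):
        return self.items

    def list(self, query_params):
        # Default behaviour: return the full collection.
        return list(self.get_items())


class Countable:
    """Intercept list() for flag parameters, otherwise defer to the next class."""

    def list(self, query_params):
        items = self.get_items()
        if "count" in query_params:
            return len(items)
        if "first" in query_params:
            # Return None for an empty collection instead of raising IndexError.
            return items[0] if items else None
        if "last" in query_params:
            return items[-1] if items else None
        # No flag present: fall through to the next list() in the MRO.
        return super().list(query_params)


class UserView(Countable, BaseListView):
    # Countable comes first so its list() runs before BaseListView.list().
    items = ["alice", "bob", "carol"]


view = UserView()
print(view.list({"count": ""}))
print(view.list({"first": ""}))
print(view.list({}))
```

If the bases were swapped to (BaseListView, Countable), BaseListView.list() would be found first and the flags would never be checked.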