# This file is taken directly from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
# GitHub: https://github.com/PrateekMunjal
# ----------------------------------------------------------

# Code modified from the VAAL codebase

import os
import torch
import numpy as np
from tqdm import tqdm

from pycls.models import vaal_model as vm
import pycls.utils.logging as lu
# import pycls.datasets.loader as imagenet_loader

logger = lu.get_logger(__name__)

# Binary cross-entropy loss shared by the training routines below; note that it
# is moved to the GPU at import time, so a CUDA device must be available
bce_loss = torch.nn.BCELoss().cuda()
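
# In VAAL, a BCE loss of this form trains the discriminator to separate latent
# codes of labeled samples from unlabeled ones. A minimal sketch (the tensor
# and module names are illustrative, not this repo's API):
#
#   lab_preds = discriminator(mu_labeled)      # sigmoid outputs in [0, 1]
#   unlab_preds = discriminator(mu_unlabeled)
#   dsc_loss = bce_loss(lab_preds, torch.ones_like(lab_preds)) \
#            + bce_loss(unlab_preds, torch.zeros_like(unlab_preds))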


def data_parallel_wrapper(model, cur_device, cfg):
    # Move the model to the given device, then replicate it across all
    # visible GPUs for single-process, multi-GPU training
    model.cuda(cur_device)
    model = torch.nn.DataParallel(
        model, device_ids=list(range(torch.cuda.device_count())))
    return model
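
# A minimal usage sketch (the torchvision model here is illustrative, not part
# of this codebase):
#
#   import torchvision
#   model = data_parallel_wrapper(torchvision.models.resnet18(), 0, cfg)
#   outputs = model(inputs)  # batches are scattered across the visible GPUs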


def distributed_wrapper(cfg, model, cur_device):
    # Transfer the model to the current GPU device
    model = model.cuda(device=cur_device)
    # Synchronize gradients across processes via DistributedDataParallel
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, device_ids=[cur_device], output_device=cur_device)
    return model
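
# Note: DistributedDataParallel assumes torch.distributed has already been
# initialized in each worker process, e.g. (a sketch; the address, port, and
# ranks are placeholders):
#
#   import torch.distributed as dist
#   dist.init_process_group(backend="nccl",
#                           init_method="tcp://localhost:29500",
#                           world_size=world_size, rank=rank)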

# ----------------------------------------------------------

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""

import pycls.utils.logging as logging
import pycls.utils.net as nu
import torch
import torch.nn as nn
from pycls.core.config import cfg

logger = logging.get_logger(__name__)


class EffHead(nn.Module):
    """EfficientNet head."""
    def __init__(self, w_in, w_out, nc):
        super(EffHead, self).__init__()
        self._construct(w_in, w_out, nc)

    def _construct(self, w_in, w_out, nc):
        # 1x1, BN, Swish
        self.conv = nn.Conv2d(w_in,
                              w_out,
                              kernel_size=1,
                              stride=1,
                              padding=0,
                              bias=False)