Example #1
    def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
        super(PNet, self).__init__()

        self.use_gpu = use_gpu

        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand

        self.shift = to_var(
            torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1))
        self.scale = to_var(torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1))

        if (self.pnet_type in ['vgg', 'vgg16']):
            self.net = pn.vgg16(pretrained=not self.pnet_rand,
                                requires_grad=False)
        elif (self.pnet_type == 'alex'):
            self.net = pn.alexnet(pretrained=not self.pnet_rand,
                                  requires_grad=False)
        elif (self.pnet_type[:-2] == 'resnet'):
            self.net = pn.resnet(pretrained=not self.pnet_rand,
                                 requires_grad=False,
                                 num=int(self.pnet_type[-2:]))
        elif (self.pnet_type == 'squeeze'):
            self.net = pn.squeezenet(pretrained=not self.pnet_rand,
                                     requires_grad=False)

        self.L = self.net.N_slices

        if (use_gpu):
            self.net.cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()
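
Every snippet on this page calls a `to_var` helper from `misc.utils` whose definition is not shown. Judging from the keyword arguments used throughout (`volatile`, `no_cuda`, `requires_grad`), it appears to wrap a tensor in an autograd `Variable` and optionally move it to the GPU, in the pre-0.4 PyTorch style these examples are written for. A minimal sketch under that assumption (the real implementation may differ):

import torch
from torch.autograd import Variable


def to_var(x, volatile=False, requires_grad=False, no_cuda=False):
    # Hypothetical reimplementation inferred from how to_var is used below:
    # move the tensor to the GPU unless no_cuda is set, then wrap it in a
    # Variable with the requested autograd flags (pre-0.4 PyTorch API).
    if torch.cuda.is_available() and not no_cuda:
        x = x.cuda()
    return Variable(x, volatile=volatile, requires_grad=requires_grad)
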
Example #2
 def LPIPS(self):
     from misc.utils import compute_lpips
     data_loader = self.data_loader
     n_images = 100
     pair_styles = 20
     model = None
     DISTANCE = {0: [], 1: []}
     self.G.eval()
     for i, (real_x, org_c, files) in tqdm(enumerate(data_loader),
                                            desc='Calculating LPIPS',
                                            total=n_images):
         for _real_x, _org_c in zip(real_x, org_c):
             _real_x = _real_x.unsqueeze(0)
             _org_c = _org_c.unsqueeze(0)
             if len(DISTANCE[_org_c[0, 0]]) >= i:
                 continue
             _real_x = to_var(_real_x, volatile=True)
             target_c = to_var(1 - _org_c, volatile=True)
             for _ in range(pair_styles):
                 style0 = to_var(self.G.random_style(_real_x.size(0)),
                                 volatile=True)
                 style1 = to_var(self.G.random_style(_real_x.size(0)),
                                 volatile=True)
                 fake_x0 = self.G(_real_x, target_c, stochastic=style0)
                 fake_x1 = self.G(_real_x, target_c, stochastic=style1)
                 distance, model = compute_lpips(fake_x0,
                                                 fake_x1,
                                                 model=model)
                 DISTANCE[_org_c[0, 0]].append(distance)
             if i == len(DISTANCE[0]) == len(DISTANCE[1]):
                 break
     print("LPISP a-b: {}".format(np.array(DISTANCE[0]).mean()))
     print("LPISP b-a: {}".format(np.array(DISTANCE[1]).mean()))
Example #3
 def debug(self):
     feed = to_var(torch.ones(1, self.color_dim, self.image_size,
                              self.image_size),
                   volatile=True,
                   no_cuda=True)
     label = to_var(torch.ones(1, self.c_dim), volatile=True, no_cuda=True)
     style = to_var(self.random_style(feed), volatile=True, no_cuda=True)
     self.apply_style(feed, label, style)
     self.generator.debug()
Example #4
 def save_multidomain_output(self, real_x, label, save_path, **kwargs):
     self.G.eval()
     self.D.eval()
     # For torch < 1.0, fall back to a throwaway file handle as a no-op
     # context manager; otherwise use torch.no_grad().
     no_grad = open('/var/tmp/null.txt',
                    'w') if get_torch_version() < 1.0 else torch.no_grad()
     with no_grad:
         real_x = to_var(real_x, volatile=True)
         n_style = self.config.style_debug
         n_interp = self.config.n_interpolation + 10
         _name = 'domain_interpolation'
         no_label = True
         for idx in range(n_style):
             dirname = save_path.replace('.jpg', '')
             filename = '{}_style{}.jpg'.format(_name,
                                                str(idx + 1).zfill(2))
             _save_path = os.path.join(dirname, filename)
             create_dir(_save_path)
             fake_image_list, fake_attn_list = self.Create_Visual_List(
                 real_x)
             style = self.G.random_style(1).repeat(real_x.size(0), 1)
             style = to_var(style, volatile=True)
             label0 = to_var(label, volatile=True)
             opposite_label = self.target_multiAttr(1 - label,
                                                    2)  # 2: black hair
             opposite_label[:, 7] = 0  # Pale skin
             label1 = to_var(opposite_label, volatile=True)
             labels = [label0, label1]
             styles = [style, style]
             domain_interp = self.MMInterpolation(labels,
                                                  styles,
                                                  n_interp=n_interp)
             for target_de in domain_interp[5:]:
                 # target_de = target_de.repeat(real_x.size(0), 1)
                 target_de = to_var(target_de, volatile=True)
                 fake_x = self.G(real_x, target_de, style, DE=target_de)
                 fake_image_list.append(to_data(fake_x[0], cpu=True))
                 fake_attn_list.append(
                     to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
             self._SAVE_IMAGE(_save_path,
                              fake_image_list,
                              no_label=no_label,
                              arrow=False,
                              circle=False)
             self._SAVE_IMAGE(_save_path,
                              fake_attn_list,
                              Attention=True,
                              arrow=False,
                              no_label=no_label,
                              circle=False)
     self.G.train()
     self.D.train()
Example #5
 def debug(self):
     PRINT(self.config.log, '-- Generator:')
     feed = to_var(torch.ones(1, self.color_dim, self.image_size,
                              self.image_size),
                   volatile=True,
                   no_cuda=True)
     features = self.print_debug(feed, self.main)
     self.print_debug(features, self.fake)
     self.print_debug(features, self.attn)
Example #6
 def debug(self):
     feed = to_var(
         torch.ones(1, self.color_dim, self.image_size, self.image_size),
         volatile=True,
         no_cuda=True)
     PRINT(self.config.log, '-- StyleEncoder:')
     features = self.print_debug(feed, self.main)
     fc_in = features.view(features.size(0), -1)
     self.print_debug(fc_in, self.fc)
Example #7
 def _CLS(self, data):
     data = to_var(data, volatile=True)
     out_label = self.D(data)[1]
     if len(out_label) > 1:
         out_label = torch.cat(
             [F.sigmoid(out.unsqueeze(-1)) for out in out_label],
             dim=-1).mean(dim=-1)
     else:
         out_label = F.sigmoid(out_label[0])
     out_label = (out_label > 0.5).float()
     return out_label
Example #8
 def debug(self):
     feed = to_var(torch.ones(1, self.color_dim, self.image_size,
                              self.image_size),
                   volatile=True,
                   no_cuda=True)
     modelList = zip(self.cnns_main, self.cnns_src, self.cnns_aux)
     for idx, outs in enumerate(modelList):
         PRINT(self.config.log, '-- MultiDiscriminator ({}):'.format(idx))
         features = self.print_debug(feed, outs[-3])
         self.print_debug(features, outs[-2])
         self.print_debug(features, outs[-1]).view(feed.size(0), -1)
         feed = self.downsample(feed)
Example #9
 def MMInterpolation(self, targets, styles, n_interp=None):
     assert len(targets) == 2 and len(styles) == 2
     if n_interp is None:
         n_interp = self.config.n_interpolation
     in_de0 = self.label2embedding(targets[0], styles[0])
     in_de1 = self.label2embedding(targets[1], styles[1])
     domain_interp = torch.zeros((n_interp, targets[0].size(0),
                                  in_de0.shape[-1]))
     domain_interp = to_var(domain_interp, volatile=True)
     for i in range(targets[0].size(0)):
         domain_interp[:, i] = interpolation(in_de0[i], in_de1[i], n_interp)
     return domain_interp
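
MMInterpolation above delegates to an `interpolation(start, end, n)` helper that is not part of this snippet. From its usage it presumably returns `n` embeddings spaced between the two endpoints; a hedged sketch of such a helper (hypothetical, not necessarily the repository's implementation):

import torch


def interpolation(start, end, n_steps):
    # Hypothetical helper: n_steps embeddings linearly interpolated between
    # two 1-D tensors of equal length, returned as an (n_steps, dim) tensor.
    alphas = torch.linspace(0, 1, n_steps).unsqueeze(1)
    return (1 - alphas) * start.unsqueeze(0) + alphas * end.unsqueeze(0)
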
Example #10
    def fit(self, configs):
        self.base_model.train()
        dataloader, optimizer = configs['dataloader'], configs['optimizer']
        try:
            flag = True
            total_steps = len(dataloader)
        except:
            flag = False
            total_steps = 1
        current_epoch = configs['current_epoch']
        total_epochs = configs['total_epochs']
        teacher_updates = configs.get('policy_step', -1)
        logger = configs['logger']

        all_correct = 0
        all_samples = 0
        loss_average = 0

        for idx, (inputs, labels) in enumerate(dataloader):
            optimizer.zero_grad()
            if flag:
                inputs = to_var(inputs)
                labels = to_var(labels)
            predicts = self.base_model(inputs)

            eval_res = self.evaluator(predicts, labels)
            num_correct = eval_res['num_correct']
            num_samples = eval_res['num_samples']
            # logger.info('num_samples %d, num_correct %d'%(num_samples, num_correct))
            loss = eval_res['loss']
            all_correct += num_correct
            all_samples += num_samples
            loss.backward()
            optimizer.step()
            logger.info('Policy Steps: [%d] Train: ----- Iteration [%d], loss: %5.4f, accuracy: %5.4f(%5.4f)' % (
                teacher_updates, current_epoch+1, loss.cpu().data[0], num_correct/num_samples, all_correct/all_samples))
            loss_average += loss.cpu().data[0]
        return loss_average/total_steps
Example #11
    def Gen_update(self, real_x, real_c, fake_c):
        self.train_model(generator=True)
        real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
        criterion_l1 = torch.nn.L1Loss()
        style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
        style_rec = to_var(self.random_style(real_x, seed=self.count_seed + 1))
        style_identity = to_var(
            self.random_style(real_x, seed=self.count_seed + 2))
        self.count_seed += 3

        fake_x = self.G(real_x, fake_c, style_fake)

        g_loss_src, g_loss_cls = self._GAN_LOSS(fake_x[0], real_x, fake_c)
        self.loss['Gsrc'] = g_loss_src
        self.loss['Gcls'] = g_loss_cls * self.config.lambda_cls

        # REC LOSS
        rec_x = self.G(fake_x[0], real_c, style_rec)
        g_loss_rec = criterion_l1(rec_x[0], real_x)
        self.loss['Grec'] = self.config.lambda_rec * g_loss_rec

        # ========== Attention Part ==========#
        self.loss['Gatm'] = self.config.lambda_mask * (torch.mean(rec_x[1]) +
                                                       torch.mean(fake_x[1]))
        self.loss['Gats'] = self.config.lambda_mask_smooth * (
            _compute_loss_smooth(rec_x[1]) + _compute_loss_smooth(fake_x[1]))

        # ========== Identity Part ==========#
        if self.config.Identity:
            idt_x = self.G(real_x, real_c, style_identity)[0]
            g_loss_idt = criterion_l1(idt_x, real_x)
            self.loss['Gidt'] = self.config.lambda_idt * g_loss_idt

        g_loss = self.current_losses('G', **self.loss)
        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()
Example #12
    def Dis_update(self, real_x, real_c, fake_c):
        self.train_model(discriminator=True)
        real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
        style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
        self.count_seed += 1
        fake_x = self.G(real_x, fake_c, style_fake)[0]
        d_loss_src, d_loss_cls = self._GAN_LOSS(real_x, fake_x, real_c)

        self.loss['Dsrc'] = d_loss_src
        self.loss['Dcls'] = d_loss_cls * self.config.lambda_cls
        d_loss = self.current_losses('D', **self.loss)
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()
Example #13
def val(dataloader, model, val_info, criterion):
    model.eval()
    num_epochs = val_info['num_epochs']
    epoch = val_info['epoch']
    total_steps = len(dataloader)
    total_loss = 0
    for idx, (x_train, x_predict) in enumerate(dataloader):
        x_train = to_var(x_train)
        x_predict = to_var(x_predict)
        data = pack([x_train, x_predict, None],
                    ['x_train', 'x_predict', 'states'])
        configs = pack([False, 10], ['use_gt', 'max_steps'])
        reconstruct, predict = model(data, configs)
        r_loss = criterion(reconstruct, x_train)
        p_loss = criterion(predict, x_predict)
        loss = r_loss + p_loss
        logger.info(
            '[Val] Epoch [%d/%d], Step [%d/%d], Reconstruct Loss: %5.4f, Predict Loss: %5.4f, Total: %5.4f'
            % (epoch, num_epochs, idx + 1, total_steps, r_loss.data[0],
               p_loss.data[0], loss.data[0]))
        total_loss += loss.data[0]

    return total_loss / total_steps
Example #14
    def val(self, configs):
        self.base_model.eval()
        dataloader = configs['dataloader']
        total_steps = len(dataloader)

        all_correct = 0
        all_samples = 0
        loss_average = 0
        for idx, (inputs, labels) in enumerate(dataloader):
            inputs = to_var(inputs, volatile=True)
            labels = to_var(labels)
            predicts = self.base_model(inputs)
            eval_res = self.evaluator(predicts, labels)
            num_correct = eval_res['num_correct']
            num_samples = eval_res['num_samples']
            all_correct += num_correct
            all_samples += num_samples
            # logger.info('Eval: Epoch [%d/%d], Iteration [%d/%d], accuracy: %5.4f(%5.4f)' % (
            #    current_epoch, total_epochs, idx, total_steps, num_correct/num_samples, all_correct/all_samples))
            loss_average += eval_res['loss'].cpu().data[0]
        # print ('Total: %d, correct: %d', all_samples, all_correct)

        return all_correct/all_samples, loss_average/total_steps
Example #15
def train(dataloader, model, optimizer, criterion, train_info):
    model.train()
    num_epochs = train_info['num_epochs']
    epoch = train_info['epoch']
    clip = train_info['clip']
    total_steps = len(dataloader)
    for idx, (x_train, x_predict) in enumerate(dataloader):
        optimizer.zero_grad()
        x_train = to_var(x_train)
        x_predict = to_var(x_predict)
        data = pack([x_train, x_predict, None],
                    ['x_train', 'x_predict', 'states'])
        configs = pack([False, 10], ['use_gt', 'max_steps'])
        reconstruct, predict = model(data, configs)
        r_loss = criterion(reconstruct, x_train)
        p_loss = criterion(predict, x_predict)
        loss = r_loss + p_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        optimizer.step()
        logger.info(
            'Epoch [%d/%d], Step [%d/%d], Reconstruct Loss: %5.4f, Predict Loss: %5.4f, Total: %5.4f'
            % (epoch, num_epochs, idx + 1, total_steps, r_loss.data[0],
               p_loss.data[0], loss.data[0]))
Example #16
    def forward(self, in0, in1):
        assert (in0.size()[0] == 1)  # currently only supports batchSize 1

        if (self.colorspace == 'RGB'):
            value = util.dssim(1. * util.tensor2im(in0.data),
                               1. * util.tensor2im(in1.data),
                               range=255.).astype('float')
        elif (self.colorspace == 'Lab'):
            value = util.dssim(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.).astype('float')
        ret_var = to_var(torch.Tensor((value, )))
        if (self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var
Example #17
    def INCEPTION_REAL(self):
        from misc.utils import load_inception
        from scipy.stats import entropy
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'Real'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)

        PRED_IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
        IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

        for i, (real_x, org_c, files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            label = torch.max(org_c, 1)[1][0]
            real_x = to_var((real_x + 1) / 2., volatile=True)
            pred = to_data(F.softmax(net(inception_up(real_x)), dim=1),
                           cpu=True).numpy()
            PRED_IS[int(label)].append(pred)

        for label in range(len(data_loader.dataset.labels[0])):
            PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
            # prior is computed from all outputs
            py = np.sum(PRED_IS[label], axis=0)
            for j in range(PRED_IS[label].shape[0]):
                pyx = PRED_IS[label][j, :]
                IS[label].append(entropy(pyx, py))

        total_is = []
        file_ = open(file_name, 'w')
        for label in range(len(data_loader.dataset.labels[0])):
            _is = np.exp(np.mean(IS[label]))
            total_is.append(_is)
            PRINT(file_, "Label {}".format(label))
            PRINT(file_, "Inception Score: {:.4f}".format(_is))
        PRINT(file_, "")
        PRINT(
            file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_is).mean(),
                np.array(total_is).std()))
        file_.close()
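
For reference, the two loops above compute the standard Inception Score estimate, IS = exp(E_x[KL(p(y|x) || p(y))]): `entropy(pyx, py)` from scipy is the KL divergence between one sample's class distribution and the marginal (entropy normalizes its inputs, so summing or averaging the predictions gives the same result), and `np.exp(np.mean(...))` exponentiates the average. A compact restatement of that computation, assuming a hypothetical `preds` array of softmax outputs:

import numpy as np
from scipy.stats import entropy


def inception_score(preds):
    # preds: (N, num_classes) softmax outputs for one label's images.
    py = preds.mean(axis=0)  # marginal p(y)
    kl = [entropy(preds[j], py) for j in range(preds.shape[0])]
    return np.exp(np.mean(kl))
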
Example #18
    def forward(self, in0, in1):
        assert (in0.size()[0] == 1)  # currently only supports batchSize 1

        if (self.colorspace == 'RGB'):
            (N, C, X, Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0 - in1)**2,
                                                     dim=1).view(N, 1, X, Y),
                                          dim=2).view(N, 1, 1, Y),
                               dim=3).view(N)
            return value
        elif (self.colorspace == 'Lab'):
            value = util.l2(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.).astype('float')
            ret_var = to_var(torch.Tensor((value, )))
            if (self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var
Example #19
    def forward(self, data, configs=None):
        _KEYS = ['x', 'states']
        x, states = unpack(data, _KEYS)
        batch_size = states[0][0].size(0)
        if x is None:
            x_c, x_h, x_w = self.cell_config['in_c'], self.cell_config[
                'in_w'], self.cell_config['in_h']
            x = to_var(torch.zeros(batch_size, 1, x_c, x_h, x_w))
        # x: batch_size, time_steps, channels, height, width
        time_steps = x.size(1)
        next_states = []
        cell_list = self.cell_list
        current_input = [x[:, t] for t in xrange(time_steps)]
        for l in xrange(self.num_layers):
            h0, c0 = states[l]
            for t in xrange(time_steps):
                data = pack([current_input[t], (h0, c0)], ['x', 'states'])
                h, c = cell_list[l](data)
                next_states.append((h, c))
                current_input[t] = h
                states[l] = (h, c)

        return states
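
Several of these examples (the `train`/`val` functions in Examples #13 and #15 and the forward above) move data around through `pack` and `unpack` helpers from the same codebase. From the call sites they look like thin dict builders and readers; a minimal sketch under that assumption:

def pack(values, keys):
    # Hypothetical helper: zip parallel lists of values and keys into a dict.
    return dict(zip(keys, values))


def unpack(data, keys):
    # Hypothetical inverse: pull the requested keys back out, in order.
    return [data[k] for k in keys]
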
Example #20
    def Modality(self, target, style, Multimodality, idx=0):
        _size = target.size(0)
        if self.config.dataset_fake in self.MultiLabel_Datasets:
            target = (self.org_label - target)**2  # Swap labels
            target = self.target_multiAttr(target, idx)
            target = to_var(target, volatile=True)

        if Multimodality == 1:
            # Random Styles
            domain_embedding = self.label2embedding(target, style, _torch=True)

        elif Multimodality == 2:
            # Style interpolation | Fixed Labels
            # The batch belongs to the same image
            style0 = style[0].repeat(_size, 1)
            style1 = style[1].repeat(_size, 1)
            targets = [target, target]
            styles = [style0, style1]
            domain_embedding = self.MMInterpolation(targets, styles)[:, 0]

        elif Multimodality == 3:
            # Style constant | Progressive swap label
            n_interp = self.config.n_interpolation + 5
            target0 = self.org_label
            target1 = target
            style = style[0].repeat(_size, 1)
            targets = [target0, target1]
            styles = [style, style]
            domain_embedding = self.MMInterpolation(targets, styles,
                                                    n_interp)[5:, 0]

        else:
            # Unimodal
            style = style[0].repeat(_size, 1)
            domain_embedding = self.label2embedding(target, style, _torch=True)

        return domain_embedding
Example #21
 def to_var(self, *args):
     vars = []
     for arg in args:
         vars.append(to_var(arg))
     return vars
Example #22
 def init_hidden(self, batch_size, cuda=False):
     return (to_var((torch.zeros(batch_size, self.h_c, self.in_h,
                                 self.in_w))),
             to_var(torch.zeros(batch_size, self.h_c, self.in_h,
                                self.in_w)))
Example #23
    def forward(self, in0, in1, retNumpy=True):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor objects of shape Nx3xXxY -
                image patches scaled to [-1,1]
            retNumpy - [False] to return as torch.Tensor,
                [True] to return as numpy array
        OUTPUT
            computed distances between in0 and in1
        '''

        self.input_ref = in0
        self.input_p0 = in1

        if (self.use_gpu):
            self.input_ref = self.input_ref.cuda()
            self.input_p0 = self.input_p0.cuda()

        self.var_ref = to_var(self.input_ref, requires_grad=True)
        self.var_p0 = to_var(self.input_p0, requires_grad=True)

        self.d0 = self.forward_pair(self.var_ref, self.var_p0)
        self.loss_total = self.d0

        def convert_output(d0):
            if (retNumpy):
                ans = d0.cpu().data.numpy()
                if not self.spatial:
                    ans = ans.flatten()
                else:
                    assert (ans.shape[0] == 1 and len(ans.shape) == 4)
                    # Reshape to usual numpy image format: (height, width,
                    # channels)
                    return ans[0, ...].transpose([1, 2, 0])
                return ans
            else:
                return d0

        if self.spatial:
            L = [convert_output(x) for x in self.d0]
            spatial_shape = self.spatial_shape
            if spatial_shape is None:
                if (self.spatial_factor is None):
                    spatial_shape = (in0.size()[2], in0.size()[3])
                else:
                    spatial_shape = (max([x.shape[0]
                                          for x in L]) * self.spatial_factor,
                                     max([x.shape[1]
                                          for x in L]) * self.spatial_factor)

            L = [
                skimage.transform.resize(x,
                                         spatial_shape,
                                         order=self.spatial_order,
                                         mode='edge') for x in L
            ]

            L = np.mean(np.concatenate(L, 2) * len(L), 2)
            return L
        else:
            return convert_output(self.d0)
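
A minimal usage sketch for the forward above, assuming an instance `dist_model` of the surrounding class (its name and constructor are not shown in this snippet):

import torch

# Two hypothetical image patches scaled to [-1, 1]; batch size 1, as the assert requires.
img0 = torch.rand(1, 3, 64, 64) * 2 - 1
img1 = torch.rand(1, 3, 64, 64) * 2 - 1
dist = dist_model.forward(img0, img1, retNumpy=True)  # numpy scalar, or a spatial map if spatial=True
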
Example #24
def train_inception(batch_size, shuffling=False, num_workers=4, **kwargs):

    from torchvision.models import inception_v3
    from misc.utils import to_var, to_cuda, to_data
    from torchvision import transforms
    from torch.utils.data import DataLoader
    import torch.nn.functional as F
    import torch
    import torch.nn as nn
    import tqdm
    import os
    import numpy as np
    from PIL import Image

    metadata_path = os.path.join('data', 'RafD', 'normal')
    # inception Norm

    image_size = 299

    transform = []
    window = int(image_size / 10)
    transform += [
        transforms.Resize((image_size + window, image_size + window),
                          interpolation=Image.ANTIALIAS)
    ]
    transform += [
        transforms.RandomResizedCrop(image_size,
                                     scale=(0.7, 1.0),
                                     ratio=(0.8, 1.2))
    ]
    transform += [transforms.RandomHorizontalFlip()]
    transform += [transforms.ToTensor()]
    transform = transforms.Compose(transform)

    dataset_train = RafD(image_size,
                         metadata_path,
                         transform,
                         'train',
                         shuffling=True,
                         **kwargs)
    dataset_test = RafD(image_size,
                        metadata_path,
                        transform,
                        'test',
                        shuffling=False,
                        **kwargs)

    train_loader = DataLoader(dataset=dataset_train,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)
    test_loader = DataLoader(dataset=dataset_test,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers)

    num_labels = len(train_loader.dataset.labels[0])
    n_epochs = 100
    net = inception_v3(pretrained=True, transform_input=True)
    net.aux_logits = False
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, num_labels)

    net_save = metadata_path + '/inception_v3/{}.pth'
    if not os.path.isdir(os.path.dirname(net_save)):
        os.makedirs(os.path.dirname(net_save))
    print("Model will be saved at: " + net_save)
    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-5)
    # loss = F.cross_entropy(output, target)
    to_cuda(net)

    for epoch in range(n_epochs):
        LOSS = {'train': [], 'test': []}
        OUTPUT = {'train': [], 'test': []}
        LABEL = {'train': [], 'test': []}

        net.eval()
        for i, (data, label,
                files) in tqdm.tqdm(enumerate(test_loader),
                                    total=len(test_loader),
                                    desc='Validating Inception V3 | RafD'):
            data = to_var(data, volatile=True)
            label = to_var(torch.max(label, dim=1)[1], volatile=True)
            out = net(data)
            loss = F.cross_entropy(out, label)
            # ipdb.set_trace()
            LOSS['test'].append(to_data(loss, cpu=True)[0])
            OUTPUT['test'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1], cpu=True).tolist())
            LABEL['test'].extend(to_data(label, cpu=True).tolist())
        acc_test = (np.array(OUTPUT['test']) == np.array(LABEL['test'])).mean()
        print('[Test] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['test']).mean(), acc_test))

        net.train()
        for i, (data, label, files) in tqdm.tqdm(
                enumerate(train_loader),
                total=len(train_loader),
                desc='[{}/{}] Train Inception V3 | RafD'.format(
                    str(epoch).zfill(5),
                    str(n_epochs).zfill(5))):
            # ipdb.set_trace()
            data = to_var(data)
            label = to_var(torch.max(label, dim=1)[1])
            out = net(data)
            # ipdb.set_trace()
            loss = F.cross_entropy(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            LOSS['train'].append(to_data(loss, cpu=True)[0])
            OUTPUT['train'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1], cpu=True).tolist())
            LABEL['train'].extend(to_data(label, cpu=True).tolist())

        acc_train = (np.array(OUTPUT['train']) == np.array(
            LABEL['train'])).mean()
        print('[Train] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['train']).mean(), acc_train))
        torch.save(net.state_dict(), net_save.format(str(epoch).zfill(5)))
        train_loader.dataset.shuffle(epoch)
Example #25
    def INCEPTION(self):
        from misc.utils import load_inception
        from scipy.stats import entropy
        n_styles = 20
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        self.G.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'SMIT'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)

        PRED_IS = {i: []
                   for i in range(len(data_loader.dataset.labels[0]))
                   }  # 0:[], 1:[], 2:[]}
        CIS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
        IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

        for i, (real_x, org_c, files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            PRED_CIS = {
                i: []
                for i in range(len(data_loader.dataset.labels[0]))
            }  # 0:[], 1:[], 2:[]}
            org_label = torch.max(org_c, 1)[1][0]
            real_x = real_x.repeat(n_styles, 1, 1, 1)  # .unsqueeze(0)
            real_x = to_var(real_x, volatile=True)

            target_c = (org_c * 0).repeat(n_styles, 1)
            target_c = to_var(target_c, volatile=True)
            for label in range(len(data_loader.dataset.labels[0])):
                if org_label == label:
                    continue
                target_c *= 0
                target_c[:, label] = 1
                style = to_var(self.G.random_style(n_styles),
                               volatile=True) if mode == 'SMIT' else None

                fake = (self.G(real_x, target_c, style)[0] + 1) / 2

                pred = to_data(F.softmax(net(inception_up(fake)), dim=1),
                               cpu=True).numpy()
                PRED_CIS[label].append(pred)
                PRED_IS[label].append(pred)

                # CIS for each image
                PRED_CIS[label] = np.concatenate(PRED_CIS[label], 0)
                py = np.sum(
                    PRED_CIS[label], axis=0
                )  # prior is computed from outputs given a specific input
                for j in range(PRED_CIS[label].shape[0]):
                    pyx = PRED_CIS[label][j, :]
                    CIS[label].append(entropy(pyx, py))

        for label in range(len(data_loader.dataset.labels[0])):
            PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
            py = np.sum(PRED_IS[label],
                        axis=0)  # prior is computed from all outputs
            for j in range(PRED_IS[label].shape[0]):
                pyx = PRED_IS[label][j, :]
                IS[label].append(entropy(pyx, py))

        total_cis = []
        total_is = []
        file_ = open(file_name, 'w')
        for label in range(len(data_loader.dataset.labels[0])):
            cis = np.exp(np.mean(CIS[label]))
            total_cis.append(cis)
            _is = np.exp(np.mean(IS[label]))
            total_is.append(_is)
            PRINT(file_, "Label {}".format(label))
            PRINT(file_, "Inception Score: {:.4f}".format(_is))
            PRINT(file_, "conditional Inception Score: {:.4f}".format(cis))
        PRINT(file_, "")
        PRINT(
            file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_is).mean(),
                np.array(total_is).std()))
        PRINT(
            file_,
            "[TOTAL] conditional Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_cis).mean(),
                np.array(total_cis).std()))
        file_.close()
Example #26
    def __init__(self,
                 pnet_type='vgg',
                 pnet_rand=False,
                 pnet_tune=False,
                 use_dropout=True,
                 use_gpu=True,
                 spatial=False,
                 version='0.1'):
        super(PNetLin, self).__init__()

        self.use_gpu = use_gpu
        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.version = version

        if (self.pnet_type in ['vgg', 'vgg16']):
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif (self.pnet_type == 'alex'):
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif (self.pnet_type == 'squeeze'):
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]

        if (self.pnet_tune):
            self.net = net_type(pretrained=not self.pnet_rand,
                                requires_grad=True)
        else:
            self.net = [
                net_type(pretrained=not self.pnet_rand, requires_grad=False),
            ]

        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        if (self.pnet_type == 'squeeze'):  # 7 layers for squeezenet
            self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
            self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
            self.lins += [self.lin5, self.lin6]

        self.shift = to_var(
            torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1))
        self.scale = to_var(torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1))

        if (use_gpu):
            if (self.pnet_tune):
                self.net.cuda()
            else:
                self.net[0].cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()
            self.lin0.cuda()
            self.lin1.cuda()
            self.lin2.cuda()
            self.lin3.cuda()
            self.lin4.cuda()
            if (self.pnet_type == 'squeeze'):
                self.lin5.cuda()
                self.lin6.cuda()
Example #27
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        self.G.eval()
        self.D.eval()
        n_rep = 4
        no_label = self.config.dataset_fake in self.Binary_Datasets
        no_grad = open('/var/tmp/null.txt',
                       'w') if get_torch_version() < 1.0 else torch.no_grad()
        with no_grad:
            real_x = to_var(real_x, volatile=True)
            out_label = to_var(label, volatile=True)
            # target_c_list = [out_label] * 7

            for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
                _name = 'multimodal'
                if interpolation == 1:
                    _name += '_interp'
                elif interpolation == 2:
                    _name = 'multidomain_interp'

                _save_path = os.path.join(
                    save_path.replace('.jpg', ''),
                    '{}_{}.jpg'.format(_name,
                                       str(idx).zfill(4)))
                create_dir(_save_path)
                real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                real_c0 = real_c0.repeat(n_rep, 1)

                fake_image_list, fake_attn_list = self.Create_Visual_List(
                    real_x0, Multimodal=True)

                target_c_list = [real_c0] * 7
                for _, target_c in enumerate(target_c_list):
                    if interpolation == 0:
                        style_ = to_var(self.G.random_style(n_rep),
                                        volatile=True)
                        embeddings = self.label2embedding(target_c,
                                                          style_,
                                                          _torch=True)
                    elif interpolation == 1:
                        style_ = to_var(self.G.random_style(1), volatile=True)
                        style1 = to_var(self.G.random_style(1), volatile=True)
                        _target_c = target_c[0].unsqueeze(0)
                        styles = [style_, style1]
                        targets = [_target_c, _target_c]
                        embeddings = self.MMInterpolation(targets,
                                                          styles,
                                                          n_interp=n_rep)[:, 0]
                    elif interpolation == 2:
                        style_ = to_var(self.G.random_style(1), volatile=True)
                        target0 = 1 - target_c[0].unsqueeze(0)
                        target1 = target_c[0].unsqueeze(0)
                        styles = [style_, style_]
                        targets = [target0, target1]
                        # import ipdb; ipdb.set_trace()
                        embeddings = self.MMInterpolation(targets,
                                                          styles,
                                                          n_interp=n_rep)[:, 0]
                    else:
                        raise ValueError(
                            "There are only 2 types of interpolation: "
                            "Multimodal and Multi-domain")
                    fake_x = self.G(real_x0, target_c, style_, DE=embeddings)
                    fake_image_list.append(to_data(fake_x[0], cpu=True))
                    fake_attn_list.append(
                        to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                self._SAVE_IMAGE(_save_path,
                                 fake_image_list,
                                 mode='style_' + chr(65 + idx),
                                 no_label=no_label,
                                 arrow=interpolation,
                                 circle=False)
                self._SAVE_IMAGE(_save_path,
                                 fake_attn_list,
                                 Attention=True,
                                 mode='style_' + chr(65 + idx),
                                 arrow=interpolation,
                                 no_label=no_label,
                                 circle=False)
        self.G.train()
        self.D.train()
Example #28
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        self.G.eval()
        self.D.eval()
        n_rep = 4
        no_label = self.config.dataset_fake in self.Binary_Datasets
        no_grad = open('/var/tmp/null.txt',
                       'w') if get_torch_version() < 1.0 else torch.no_grad()
        with no_grad:
            real_x = to_var(real_x, volatile=True)
            out_label = to_var(label, volatile=True)
            # target_c_list = [out_label] * 7

            for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
                _name = 'multimodal'
                if interpolation:
                    _name = _name + '_interp'
                _save_path = os.path.join(
                    save_path.replace('.jpg', ''), '{}_{}.jpg'.format(
                        _name,
                        str(idx).zfill(4)))
                create_dir(_save_path)
                real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                real_c0 = real_c0.repeat(n_rep, 1)

                fake_image_list = [
                    to_data(
                        color_frame(
                            single_source(real_x0),
                            thick=5,
                            color='green',
                            first=True),
                        cpu=True)
                ]
                fake_attn_list = [
                    to_data(
                        color_frame(
                            single_source(real_x0),
                            thick=5,
                            color='green',
                            first=True),
                        cpu=True)
                ]

                target_c_list = [real_c0] * 7
                for _, target_c in enumerate(target_c_list):
                    # target_c = _target_c#[0].repeat(n_rep, 1)
                    if not interpolation:
                        style_ = self.G.random_style(n_rep)
                    else:
                        z0 = to_data(
                            self.G.random_style(1), cpu=True).numpy()[0]
                        z1 = to_data(
                            self.G.random_style(1), cpu=True).numpy()[0]
                        style_ = self.G.random_style(n_rep)
                        style_[:] = torch.FloatTensor(
                            np.array([
                                slerp(sz, z0, z1)
                                for sz in np.linspace(0, 1, n_rep)
                            ]))
                    style = to_var(style_, volatile=True)
                    fake_x = self.G(real_x0, target_c, stochastic=style)
                    fake_image_list.append(to_data(fake_x[0], cpu=True))
                    fake_attn_list.append(
                        to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                self._SAVE_IMAGE(
                    _save_path,
                    fake_image_list,
                    mode='style_' + chr(65 + idx),
                    no_label=no_label,
                    arrow=interpolation,
                    circle=False)
                self._SAVE_IMAGE(
                    _save_path,
                    fake_attn_list,
                    Attention=True,
                    mode='style_' + chr(65 + idx),
                    arrow=interpolation,
                    no_label=no_label,
                    circle=False)
        self.G.train()
        self.D.train()
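
The interpolation branch above relies on a `slerp(t, z0, z1)` helper that is not included in the snippet; presumably it is the usual spherical linear interpolation between two latent style vectors. A hedged sketch of such a helper (hypothetical, shown only to make the loop self-contained):

import numpy as np


def slerp(t, z0, z1):
    # Spherical linear interpolation between 1-D numpy vectors z0 and z1, t in [0, 1].
    omega = np.arccos(np.clip(
        np.dot(z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1)), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        return (1.0 - t) * z0 + t * z1  # nearly parallel: fall back to lerp
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)
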
Example #29
    def val_teacher(self, configs):
        # TODO: test the policy; plot the curve of #effective_samples vs. test accuracy
        '''
        :param configs:
            Required:
                state_func
                dataloader: student/dev/test
                optimizer: student
                lr_scheduler: student
                logger
            Optional:
                threshold
                M
                num_classes
                max_t
                (Note: should be consistent with training)
        :return:
        '''
        teacher = self.teacher_net
        # ==================== train student from scratch ============
        init_params(self.student_net)
        student = self.student_net
        # ==================== fetch configs [optional] ===============
        threshold = configs.get('threshold', 0.5)
        M = configs.get('M', 128)
        num_classes = configs.get('num_classes', 10)
        max_t = configs.get('max_t', 50000)
        # =================== fetch configs [required] ================
        state_func = configs['state_func']
        student_dataloader = configs['dataloader']['student']
        dev_dataloader = configs['dataloader']['dev']
        test_dataloader = configs['dataloader']['test']
        student_optimizer = configs['optimizer']['student']
        student_lr_scheduler = configs['lr_scheduler']['student']
        logger = configs['logger']

        # ================== init tracking history ====================
        training_loss_history = []
        val_loss_history = []

        student_updates = 0
        best_acc_on_dev = 0
        best_acc_on_test = 0
        i_tau = 0
        effective_num = 0
        effnum_acc_curves = []

        while i_tau < max_t:
            i_tau += 1
            count = 0
            input_pool = []
            label_pool = []
            # ================== collect training batch ============
            for idx, (inputs, labels) in enumerate(student_dataloader):
                inputs = to_var(inputs)
                labels = to_var(labels)
                state_configs = {
                    'num_classes': num_classes,
                    'labels': labels,
                    'inputs': inputs,
                    'student': student,
                    'current_iter': i_tau,
                    'max_iter': max_t,
                    'train_loss_history': training_loss_history,
                    'val_loss_history': val_loss_history
                }
                states = state_func(
                    state_configs
                )  # TODO: implement the function for computing state
                _inputs = {'input': states}
                predicts = teacher(_inputs, None)

                indices = torch.nonzero(predicts.data.squeeze() >= threshold)
                if len(indices) == 0:
                    continue
                count += len(indices)
                # selected_inputs = torch.gather(inputs, 0, indices.squeeze()).view(len(indices),
                #                                                                  *inputs.size()[1:])
                # selected_labels = torch.gather(labels, 0, indices.squeeze()).view(-1, 1)
                # import pdb
                # pdb.set_trace()
                selected_inputs = inputs[indices.squeeze()].view(
                    len(indices),
                    *inputs.size()[1:])
                selected_labels = labels[indices.squeeze()].view(-1, 1)
                input_pool.append(selected_inputs)
                label_pool.append(selected_labels)
                if count >= M:
                    effective_num += count
                    break

            # ================== prepare training data =============
            inputs = torch.cat(input_pool, 0)
            labels = torch.cat(label_pool, 0)
            st_configs = {
                'dataloader': to_generator([inputs, labels]),
                'optimizer': student_optimizer,
                'current_epoch': student_updates,
                'total_epochs': -1,
                'logger': logger
            }
            # ================= feed the selected batch ============
            train_loss = student.fit(st_configs)
            training_loss_history.append(train_loss)
            student_updates += 1
            student_lr_scheduler(student_optimizer, student_updates)

            # ================ test on dev set =====================
            st_configs['dataloader'] = dev_dataloader
            acc, val_loss = student.val(st_configs)
            best_acc_on_dev = acc if best_acc_on_dev < acc else best_acc_on_dev
            logger.info(
                'Test on Dev: Iteration [%d], accuracy: %5.4f, best: %5.4f' %
                (student_updates, acc, best_acc_on_dev))
            val_loss_history.append(val_loss)

            # =============== test on test set ======================
            st_configs['dataloader'] = test_dataloader
            acc, test_loss = student.val(st_configs)
            best_acc_on_test = acc if best_acc_on_test < acc else best_acc_on_test
            logger.info(
                'Testing Set: Iteration [%d], accuracy: %5.4f, best: %5.4f' %
                (student_updates, acc, best_acc_on_test))
            effnum_acc_curves.append((effective_num, acc))
        return effnum_acc_curves
Example #30
    def fit_teacher(self, configs):
        '''
        :param configs:
            Required:
                state_func: [function] used to compute the state vector

                dataloader: [dict]
                    teacher: teacher training data loader
                    student: student training data loader
                    dev: for testing the student model so as to compute reward for the teacher
                    test: student testing data loader

                optimizer: [dict]
                    teacher: the optimizer for teacher
                    student: the optimizer for student

                lr_scheduler: [dict]
                    teacher: the learning rate scheduler for the teacher model
                    student: the learning rate scheduler for the student model

                (removed) current_epoch: [int] the current epoch
                (removed) total_epochs: the max number of epochs to train the model
                logger: the logger

            Optional:
                max_t: [int] [50,000]
                    the maximum number of iterations before stopping the teaching;
                    once this number is reached, return a reward of 0.
                tau: [float32] [0.8]
                    the expected accuracy of the student model on dev set
                threshold: [float32] [0.5]
                    the probability threshold for choosing a sample.
                M: [int] [128]
                    the required batch-size for training the student model.
                max_non_increasing_steps: [int] [10]
                    The maximum number of consecutive iterations without an increase in reward.
                    If exceeded, stop training the teacher model.
                num_classes: [int] [10]
                    the number of classes in the training set.
        :return:
        '''
        teacher = self.teacher_net
        student = self.student_net
        # ==================== fetch configs [optional] ===============
        max_t = configs['max_t']
        tau = configs['tau']
        M = configs['M']
        max_non_increasing_steps = configs['max_non_increasing_steps']
        num_classes = configs['num_classes']

        # =================== fetch configs [required] ================
        state_func = configs['state_func']
        teacher_dataloader = configs['dataloader']['teacher']
        dev_dataloader = configs['dataloader']['dev']
        teacher_optimizer = configs['optimizer']['teacher']
        student_optimizer = configs['optimizer']['student']
        teacher_lr_scheduler = configs['lr_scheduler']['teacher']
        student_lr_scheduler = configs['lr_scheduler']['student']
        logger = configs['logger']

        # ================== init tracking history ====================
        rewards = []
        training_loss_history = []
        val_loss_history = []
        num_steps_to_achieve = []

        non_increasing_steps = 0
        student_updates = 0
        teacher_updates = 0
        best_acc_on_dev = 0
        while True:
            i_tau = 0
            actions = []

            def overloaded_init_params(x):
                init_params(x)
                # if pointer == 0:
                #    init_params(x)
                # else:
                #     file_name = './model/resnet34-%5.4f.pth.tar' % (tau_list[pointer - 1])
                #     logger.info('Loaded model from' + file_name)
                #     x.load_state_dict(torch.load(file_name)['state_dict'])

            while i_tau < max_t:
                i_tau += 1
                count = 0
                input_pool = []
                label_pool = []
                # ================== collect training batch ============
                while True:
                    for idx, (inputs, labels) in enumerate(teacher_dataloader):
                        inputs = to_var(inputs)
                        labels = to_var(labels)
                        state_configs = {
                            'num_classes': num_classes,
                            'labels': labels,
                            'inputs': inputs,
                            'student': student.train(),
                            'current_iter': i_tau,
                            'max_iter': max_t,
                            'train_loss_history': training_loss_history,
                            'val_loss_history': val_loss_history
                        }
                        states = state_func(
                            state_configs
                        )  # TODO: implement the function for computing state
                        _inputs = {'input': states.detach()}
                        predicts = teacher(_inputs, None)
                        sampled_actions = torch.bernoulli(
                            predicts.data.squeeze())
                        indices = torch.nonzero(sampled_actions)

                        if len(indices) == 0:
                            #print (predicts.data.squeeze())
                            continue
                        # print ('Selected %d/%d samples'%(len(indices), len(labels)))
                        count += len(indices)
                        selected_inputs = inputs[indices.squeeze()].view(
                            len(indices),
                            *inputs.size()[1:])
                        selected_labels = labels[indices.squeeze()].view(-1, 1)
                        input_pool.append(selected_inputs)
                        label_pool.append(selected_labels)
                        actions.append(
                            torch.log(predicts.squeeze()) *
                            to_var(sampled_actions - 0.5) * 2)
                        if count >= M:
                            break
                    if count >= M:
                        break

                # ================== prepare training data =============
                inputs = torch.cat(input_pool, 0)
                labels = torch.cat(label_pool, 0)
                st_configs = {
                    'dataloader': to_generator([inputs, labels]),
                    'optimizer': student_optimizer,
                    'current_epoch': student_updates,
                    'total_epochs': 0,
                    'logger': logger,
                    'policy_step': teacher_updates
                }
                # ================= feed the selected batch ============
                train_loss = student.fit(st_configs)
                training_loss_history.append(train_loss)
                student_updates += 1
                student_lr_scheduler(student_optimizer, student_updates)
                # ================ test on dev set =====================
                st_configs['dataloader'] = dev_dataloader
                acc, val_loss = student.val(st_configs)
                best_acc_on_dev = acc if best_acc_on_dev < acc else best_acc_on_dev
                logger.info(
                    'Stage [%d], Policy Steps: [%d] Test on Dev: Iteration [%d], accuracy: %5.4f, best: %5.4f, '
                    'loss: %5.4f' % (0, teacher_updates, student_updates, acc,
                                     best_acc_on_dev, val_loss))
                val_loss_history.append(val_loss)
                # ============== check if reach the expected accuracy or exceeds the max_t ==================
                if acc >= tau or i_tau == max_t:
                    num_steps_to_achieve.append(i_tau)
                    teacher_optimizer.zero_grad()

                    reward = -math.log(float(i_tau) / max_t)
                    baseline = 0 if len(
                        rewards) == 0 else 0.8 * baseline + 0.2 * reward
                    last_reward = 0 if len(rewards) == 0 else rewards[-1]

                    if last_reward >= reward:
                        non_increasing_steps += 1
                    else:
                        non_increasing_steps = 0

                    loss = -sum([torch.sum(_)
                                 for _ in actions]) * (reward - baseline)
                    print('=' * 80)
                    print(actions[0])
                    print('=' * 80)
                    logger.info(
                        'Policy: Iterations [%d], stops at %d/%d to achieve %5.4f, loss: %5.4f, '
                        'reward: %5.4f(%5.4f)' %
                        (teacher_updates, i_tau, max_t, acc,
                         loss.cpu().data[0], reward, baseline))
                    rewards.append(reward)
                    loss.backward()
                    teacher_optimizer.step()
                    for name, param in teacher.named_parameters():
                        print(name, param)
                    teacher_updates += 1
                    teacher_lr_scheduler(teacher_optimizer, teacher_updates)

                    # ========= reinitialize the student network =========
                    overloaded_init_params(self.student_net)
                    student_updates = 0
                    best_acc_on_dev = 0
                    print('Initialized the student net\'s parameters')
                    # ========== break for next batch ====================
                    break

            # ==================== policy converged (stopping criteria) ==
            if non_increasing_steps >= max_non_increasing_steps:
                torch.save({'num_steps_to_achieve': num_steps_to_achieve},
                           './tmp/curve_stage_%d.pth.tar' % 0)
                print(num_steps_to_achieve)
                return num_steps_to_achieve