def as_cutmix(input, target, conf, model=None):
    """Asymmetric CutMix: paste a resized patch from a shuffled batch.

    With probability ``conf.prob`` two boxes are drawn independently with
    ``utils.rand_bbox`` (both from the same sampled ``lam``). The first box is
    the destination; the second is the source, cropped from a permutation of
    the batch, bilinearly resized to the destination size and written into
    ``input`` in place. Mixing is skipped when either box covers 4 elements
    or fewer.

    Args:
        input: Image batch indexed as ``[batch, channel, dim2, dim3]``. It is
            modified in place when a patch is pasted.
        target: Labels, indexed along the batch dimension.
        conf: Object exposing ``prob`` (chance of mixing) and ``beta``
            (parameter of the symmetric Beta distribution for ``lam``).
        model: Unused; kept so all mixing functions share one signature.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` where ``target_b`` holds
        the labels of the permuted batch (a clone of ``target`` when no
        permutation was drawn) and ``lam_a`` / ``lam_b`` are per-sample
        weights with ``lam_a + lam_b == 1``, placed on ``input``'s device.
    """
    # Follow the input's device instead of assuming CUDA is available.
    device = input.device
    batch_size = input.size(0)

    r = np.random.rand(1)
    lam_a = torch.ones(batch_size)
    target_b = target.clone()

    if r < conf.prob:
        lam = np.random.beta(conf.beta, conf.beta)
        # Permute on CPU (same RNG stream as before), then move to the input.
        rand_index = torch.randperm(batch_size).to(device)
        target_b = target[rand_index]

        # Destination box, then an independently drawn source box.
        bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
        bbx1_1, bby1_1, bbx2_1, bby2_1 = utils.rand_bbox(input.size(), lam)

        dst_area = (bby2 - bby1) * (bbx2 - bbx1)
        src_area = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

        if src_area > 4 and dst_area > 4:
            # Indexing with a tensor already yields a copy, so the in-place
            # write below cannot alias the patch being read.
            ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1]
            ncont = F.interpolate(ncont,
                                  size=(bbx2 - bbx1, bby2 - bby1),
                                  mode='bilinear',
                                  align_corners=True)
            input[:, :, bbx1:bbx2, bby1:bby2] = ncont
            # adjust lambda to exactly match pixel ratio
            kept_ratio = 1 - (dst_area /
                              (input.size()[-1] * input.size()[-2]))
            lam_a = torch.full((batch_size,), float(kept_ratio))

    lam_b = 1 - lam_a

    return input, target, target_b, lam_a.to(device), lam_b.to(device)
示例#2
0
def snapmix(input, target, conf, model=None):
    """SnapMix: paste a resized patch and weight labels by saliency-map mass.

    With probability ``conf.prob`` the per-sample maps returned by ``get_spm``
    are used to derive label weights: ``lam_a`` is one minus the share of each
    sample's map that falls inside the destination box, and ``lam_b`` is the
    share of the permuted sample's map inside the source box. When a sample
    and its permuted partner have the same label, both weights become the sum
    of the two. NaN weights fall back to the area-based CutMix ratio.

    Args:
        input: Image batch indexed as ``[batch, channel, dim2, dim3]``. It is
            modified in place when a patch is pasted.
        target: Labels, indexed along the batch dimension.
        conf: Object exposing ``prob`` and ``beta``; also forwarded to
            ``get_spm``.
        model: Forwarded to ``get_spm``.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` with the per-sample
        weights placed on ``input``'s device. Without mixing ``lam_a`` is all
        ones and ``lam_b`` all zeros.
    """
    # Follow the input's device instead of assuming CUDA is available.
    device = input.device

    r = np.random.rand(1)
    lam_a = torch.ones(input.size(0))
    lam_b = 1 - lam_a
    target_b = target.clone()

    if r < conf.prob:
        # The maps must be computed before ``input`` is modified below.
        # NOTE(review): assumes wfmaps is [batch, dim1, dim2] on input's
        # device, matching the indexing below - confirm against get_spm.
        wfmaps, _ = get_spm(input, target, conf, model)
        bs = input.size(0)
        lam = np.random.beta(conf.beta, conf.beta)
        lam1 = np.random.beta(conf.beta, conf.beta)
        rand_index = torch.randperm(bs).to(device)
        wfmaps_b = wfmaps[rand_index, :, :]
        target_b = target[rand_index]

        same_label = target == target_b
        # Destination box from lam, source box from lam1.
        bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
        bbx1_1, bby1_1, bbx2_1, bby2_1 = utils.rand_bbox(input.size(), lam1)

        area = (bby2 - bby1) * (bbx2 - bbx1)
        area1 = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

        if area1 > 0 and area > 0:
            # Indexing with a tensor already yields a copy of the patch.
            ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1]
            ncont = F.interpolate(ncont,
                                  size=(bbx2 - bbx1, bby2 - bby1),
                                  mode='bilinear',
                                  align_corners=True)
            input[:, :, bbx1:bbx2, bby1:bby2] = ncont
            lam_a = 1 - wfmaps[:, bbx1:bbx2, bby1:bby2].sum(2).sum(1) / (
                wfmaps.sum(2).sum(1) + 1e-8)
            lam_b = wfmaps_b[:, bbx1_1:bbx2_1, bby1_1:bby2_1].sum(2).sum(1) / (
                wfmaps_b.sum(2).sum(1) + 1e-8)
            # Same-label pairs: both weights become lam_a + lam_b. The copy
            # keeps the pre-update lam_a for the second addition.
            tmp = lam_a.clone()
            lam_a[same_label] += lam_b[same_label]
            lam_b[same_label] += tmp[same_label]
            # Area-based ratio, used only to replace NaN weights.
            lam = 1 - (area / (input.size()[-1] * input.size()[-2]))
            lam_a[torch.isnan(lam_a)] = lam
            lam_b[torch.isnan(lam_b)] = 1 - lam

    return input, target, target_b, lam_a.to(device), lam_b.to(device)
示例#3
0
def cutmix(data, labels, alpha):
    """Mix a batch with CutMix and return it with the partner labels.

    A box drawn by ``rand_bbox`` is overwritten in every sample with the same
    region taken from a randomly permuted copy of the batch. ``data`` is
    modified in place.

    Args:
        data: Batch indexed as ``[batch, channel, dim2, dim3]``.
        labels: Labels, indexed along the batch dimension.
        alpha: Parameter of the symmetric Beta distribution for the ratio.

    Returns:
        ``(data, shuffled_labels, lam)`` where ``lam`` is one minus the
        fraction of the last two dimensions covered by the pasted box.
    """
    perm = torch.randperm(data.size(0))
    partner_labels = labels[perm]

    ratio = np.random.beta(alpha, alpha)
    x_lo, y_lo, x_hi, y_hi = rand_bbox(data.size(), ratio)

    # Indexing with ``perm`` produces a copy, so the write cannot alias it.
    patch = data[perm, :, x_lo:x_hi, y_lo:y_hi]
    data[:, :, x_lo:x_hi, y_lo:y_hi] = patch

    # Recompute the ratio from the actual box so it matches the pixel share.
    box_area = (x_hi - x_lo) * (y_hi - y_lo)
    full_area = data.size()[-1] * data.size()[-2]
    ratio = 1 - (box_area / full_area)

    return data, partner_labels, ratio
def cutout(input, target, conf=None, model=None):
    """Cutout: zero a random box in every image of the batch.

    With probability ``conf.prob`` one box is drawn by ``utils.rand_bbox``
    using a fixed ratio argument of 0.75 and that region is set to zero in
    ``input`` in place. Labels are never mixed, so the returned weights are
    always all ones / all zeros.

    Args:
        input: Image batch indexed as ``[batch, channel, dim2, dim3]``.
        target: Labels, indexed along the batch dimension.
        conf: Object exposing ``prob``. The ``None`` default only mirrors the
            shared signature; passing ``None`` raises ``AttributeError``.
        model: Unused; kept so all mixing functions share one signature.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` where ``target_b`` is a
        clone of ``target``, ``lam_a`` is all ones and ``lam_b`` all zeros,
        both on ``input``'s device.
    """
    # Follow the input's device instead of assuming CUDA is available.
    device = input.device
    cut_ratio = 0.75  # fixed ratio argument handed to rand_bbox

    r = np.random.rand(1)
    lam_a = torch.ones(input.size(0), device=device)
    lam_b = 1 - lam_a
    target_b = target.clone()

    if r < conf.prob:
        bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), cut_ratio)
        input[:, :, bbx1:bbx2, bby1:bby2] = 0

    return input, target, target_b, lam_a, lam_b
def cutmix(input, target, conf, model=None):
    """CutMix: replace a random box with the same box from a shuffled batch.

    With probability ``conf.prob`` a box is drawn by ``utils.rand_bbox`` and
    overwritten in every sample with that region from a permuted copy of the
    batch. ``input`` is modified in place.

    Args:
        input: Image batch indexed as ``[batch, channel, dim2, dim3]``.
        target: Labels, indexed along the batch dimension.
        conf: Object exposing ``prob`` (chance of mixing) and ``beta``
            (parameter of the symmetric Beta distribution for the ratio).
        model: Unused; kept so all mixing functions share one signature.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` where ``lam_a`` is one
        minus the share of the last two dimensions covered by the box
        (broadcast to one value per sample), ``lam_b = 1 - lam_a``, both on
        ``input``'s device. Without mixing ``lam_a`` is all ones.
    """
    # Follow the input's device instead of assuming CUDA is available.
    device = input.device
    batch_size = input.size(0)

    r = np.random.rand(1)
    lam_a = torch.ones(batch_size, device=device)
    target_b = target.clone()

    if r < conf.prob:
        lam = np.random.beta(conf.beta, conf.beta)
        # Permute on CPU (same RNG stream as before), then move to the input.
        rand_index = torch.randperm(batch_size).to(device)
        target_b = target[rand_index]
        # Indexing with a tensor already yields a copy, so no clone is needed.
        input_b = input[rand_index]
        bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
        input[:, :, bbx1:bbx2, bby1:bby2] = input_b[:, :, bbx1:bbx2, bby1:bby2]

        # adjust lambda to exactly match pixel ratio
        kept_ratio = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                          (input.size()[-1] * input.size()[-2]))
        lam_a = torch.full((batch_size,), float(kept_ratio), device=device)

    lam_b = 1 - lam_a

    return input, target, target_b, lam_a, lam_b
示例#6
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch, applying CutMix to every batch if enabled.

    CutMix is used for all batches whenever both ``args.beta`` and
    ``args.mix_prob`` are positive (no per-batch random draw is made);
    otherwise the plain criterion is used. The function reads the
    module-level ``args`` and ``start_epoch``.

    Args:
        train_loader: Iterable yielding ``(imgs, labels, bbox)`` triples; the
            third element is not used here.
        model: Called as ``model(imgs, is_train, rand_index, lam)`` for mixed
            batches and ``model(imgs, is_train)`` otherwise.
        criterion: Loss callable taking ``(output, labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only for logging.

    Returns:
        ``(top1.avg, losses.avg)`` accumulated over the epoch.
    """
    time_meter = utils.AverageMeter()
    load_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter()
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()

    # Put the model in training mode before iterating.
    model.train()

    tick = time.time()
    lr_now = utils.get_learning_rate(optimizer)[0]
    for step, (imgs, labels, _bbox) in enumerate(train_loader):
        # Time spent waiting on the loader.
        load_meter.update(time.time() - tick)

        imgs, labels = imgs.cuda(), labels.cuda()

        if not (args.beta > 0 and args.mix_prob > 0):
            # Plain forward pass; the model is told no mixing happened.
            output = model(imgs, False)
            loss = criterion(output, labels)
        else:
            # Build the mixed sample.
            mix_lam = np.random.beta(args.beta, args.beta)
            perm = torch.randperm(imgs.size()[0]).cuda()
            partner_labels = labels[perm]

            x_lo, y_lo, x_hi, y_hi = utils.rand_bbox(imgs.size(), mix_lam)
            imgs[:, :, x_lo:x_hi, y_lo:y_hi] = imgs[perm, :, x_lo:x_hi,
                                                    y_lo:y_hi]

            # Recompute lambda so it matches the pasted pixel share exactly.
            box_area = (x_hi - x_lo) * (y_hi - y_lo)
            mix_lam = 1 - (box_area / (imgs.size()[-1] * imgs.size()[-2]))

            output = model(imgs, True, perm, mix_lam)
            loss_own = criterion(output, labels)
            loss_partner = criterion(output, partner_labels)
            loss = loss_own * mix_lam + loss_partner * (1. - mix_lam)

        # Metrics are measured against the original (unmixed) labels.
        err1, err5 = utils.accuracy(output.data, labels, topk=(1, 5))

        n_samples = imgs.size(0)
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(err1.item(), n_samples)
        top5_meter.update(err5.item(), n_samples)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Whole-iteration wall time.
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0 and args.verbose == True:
            progress = ('Epoch: [{0}/{1}][{2}/{3}]\t'
                        'LR: {LR:.6f}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top 1-acc {top1.val:.4f} ({top1.avg:.4f})\t'
                        'Top 5-acc {top5.val:.4f} ({top5.avg:.4f})')
            print(progress.format(epoch,
                                  args.epochs + start_epoch,
                                  step,
                                  len(train_loader),
                                  LR=lr_now,
                                  batch_time=time_meter,
                                  data_time=load_meter,
                                  loss=loss_meter,
                                  top1=top1_meter,
                                  top5=top5_meter))

    summary = '* Epoch: [{0}/{1}]\t Top 1-acc {top1.avg:.3f}  Top 5-acc {top5.avg:.3f}\t Train Loss {loss.avg:.3f} \n'
    print(summary.format(epoch,
                         args.epochs + start_epoch,
                         top1=top1_meter,
                         top5=top5_meter,
                         loss=loss_meter))

    return top1_meter.avg, loss_meter.avg
示例#7
0
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch with per-batch random CutMix.

    For each batch a uniform random number is drawn; CutMix is applied when
    ``args.beta`` is positive and the draw is below ``args.mix_prob``. The
    function reads the module-level ``args`` and ``start_epoch``.

    Args:
        train_loader: Iterable yielding ``(input, target)`` pairs.
        model: Called as ``model(x, is_train, rand_index, lam)`` for mixed
            batches and ``model(x, is_train)`` otherwise.
        criterion: Loss callable taking ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only for logging.

    Returns:
        ``(losses.avg, top1.val)`` - the running average loss and the top-1
        value recorded for the last batch.
    """
    time_meter = AverageMeter()
    load_meter = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    # Put the model in training mode before iterating.
    model.train()

    tick = time.time()
    lr_now = get_learning_rate(optimizer)[0]
    for step, (input, target) in enumerate(train_loader):
        # Time spent waiting on the loader.
        load_meter.update(time.time() - tick)

        input, target = input.cuda(), target.cuda()

        draw = np.random.rand(1)
        apply_mix = args.beta > 0 and draw < args.mix_prob
        if apply_mix:
            # Build the mixed sample.
            mix_lam = np.random.beta(args.beta, args.beta)
            perm = torch.randperm(input.size()[0]).cuda()
            partner_target = target[perm]

            x_lo, y_lo, x_hi, y_hi = rand_bbox(input.size(), mix_lam)
            input[:, :, x_lo:x_hi, y_lo:y_hi] = input[perm, :, x_lo:x_hi,
                                                      y_lo:y_hi]

            # Recompute lambda so it matches the pasted pixel share exactly.
            box_area = (x_hi - x_lo) * (y_hi - y_lo)
            mix_lam = 1 - (box_area / (input.size()[-1] * input.size()[-2]))

            # Forward pass on the mixed batch.
            x_var = torch.autograd.Variable(input, requires_grad=True)
            own_var = torch.autograd.Variable(target)
            partner_var = torch.autograd.Variable(partner_target)
            output = model(x_var, True, perm, mix_lam)
            loss_own = criterion(output, own_var)
            loss_partner = criterion(output, partner_var)
            loss = loss_own * mix_lam + loss_partner * (1. - mix_lam)
        else:
            # Plain forward pass; the model is told no mixing happened.
            x_var = torch.autograd.Variable(input, requires_grad=True)
            y_var = torch.autograd.Variable(target)
            output = model(x_var, False)
            loss = criterion(output, y_var)

        # Metrics are measured against the original (unmixed) targets.
        err1, err5 = accuracy(output.data, target, topk=(1, 5))

        n_samples = input.size(0)
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(err1.item(), n_samples)
        top5_meter.update(err5.item(), n_samples)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Whole-iteration wall time.
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0 and args.verbose == True:
            progress = ('Epoch: [{0}/{1}][{2}/{3}]\t'
                        'LR: {LR:.6f}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                        'Top 5-err {top5.val:.4f} ({top5.avg:.4f})')
            print(progress.format(epoch,
                                  args.epochs + start_epoch,
                                  step,
                                  len(train_loader),
                                  LR=lr_now,
                                  batch_time=time_meter,
                                  data_time=load_meter,
                                  loss=loss_meter,
                                  top1=top1_meter,
                                  top5=top5_meter))

    summary = '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f}  Top 5-err {top5.avg:.3f}\t Train Loss {loss.avg:.3f} \n'
    print(summary.format(epoch,
                         args.epochs + start_epoch,
                         top1=top1_meter,
                         top5=top5_meter,
                         loss=loss_meter))

    return loss_meter.avg, top1_meter.val
示例#8
0
    def step(self, batch):
        """Run one optimisation step on ``batch`` and update running metrics.

        Depending on ``self.cfg.INPUT`` flags the batch is augmented with
        mixup, RICAP and/or CutMix (CutMix with probability 0.5) before the
        forward pass, and the loss is combined accordingly. When the flags
        overlap, the loss branch is chosen in the order RICAP, mixup, CutMix.

        Args:
            batch: ``(data, target)`` pair; both are moved to ``self.device``.

        Returns:
            ``(loss_avg, acc_avg, f1_avg)`` - the current ``avg`` values of
            the three running meters held on ``self``.
        """
        self.model.train()
        self.optim.zero_grad()
        data, target = batch
        data, target = data.to(self.device), target.to(self.device)

        if self.cfg.INPUT.USE_MIX_UP:
            data, target_a, target_b, lam = mixup_data(data, target, 0.4, True)
        self.use_cut_mix = False
        if self.cfg.INPUT.USE_RICAP:
            # The batch lives in ``data``; the previous code read the builtin
            # ``input`` here, which has no ``size`` attribute.
            I_x, I_y = data.size()[2:]

            # NOTE(review): ``args`` is not defined in this method; this
            # relies on a module-level ``args`` with ``ricap_beta`` - confirm.
            w = int(
                np.round(I_x *
                         np.random.beta(args.ricap_beta, args.ricap_beta)))
            h = int(
                np.round(I_y *
                         np.random.beta(args.ricap_beta, args.ricap_beta)))
            # Sizes of the four tiles that together cover I_x by I_y.
            w_ = [w, I_x - w, w, I_x - w]
            h_ = [h, h, I_y - h, I_y - h]

            cropped_images = {}
            c_ = {}
            W_ = {}
            for k in range(4):
                # Each tile comes from an independently shuffled batch.
                idx = torch.randperm(data.size(0))
                x_k = np.random.randint(0, I_x - w_[k] + 1)
                y_k = np.random.randint(0, I_y - h_[k] + 1)
                cropped_images[k] = data[idx][:, :, x_k:x_k + w_[k],
                                              y_k:y_k + h_[k]]
                c_[k] = target[idx].to(self.device)
                # Tile weight = its share of the full image area.
                W_[k] = w_[k] * h_[k] / (I_x * I_y)

            patched_images = torch.cat((torch.cat(
                (cropped_images[0], cropped_images[1]),
                2), torch.cat((cropped_images[2], cropped_images[3]), 2)), 3)
            data = patched_images.to(self.device)

        if self.cfg.INPUT.USE_CUT_MIX:
            r = np.random.rand(1)
            if r < 0.5:
                self.use_cut_mix = True
                lam = np.random.beta(1.0, 1.0)
                # Keep the index on the batch's device instead of forcing CUDA.
                rand_index = torch.randperm(data.size()[0]).to(data.device)
                target_a = target
                target_b = target[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
                data[:, :, bbx1:bbx2, bby1:bby2] = data[rand_index, :,
                                                        bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                           (data.size()[-1] * data.size()[-2]))
                # compute output
                data = torch.autograd.Variable(data, requires_grad=True)
                target_a_var = torch.autograd.Variable(target_a)
                target_b_var = torch.autograd.Variable(target_b)

        outputs = self.model(data)

        if self.cfg.INPUT.USE_RICAP:
            loss = sum(
                [W_[k] * self.loss_func(outputs, c_[k]) for k in range(4)])
        elif self.cfg.INPUT.USE_MIX_UP:
            loss1 = self.loss_func(outputs, target_a)
            loss2 = self.loss_func(outputs, target_b)
            loss = lam * loss1 + (1 - lam) * loss2
        elif self.cfg.INPUT.USE_CUT_MIX and self.use_cut_mix:
            loss1 = self.loss_func(outputs, target_a_var)
            loss2 = self.loss_func(outputs, target_b_var)
            loss = lam * loss1 + (1 - lam) * loss2
        else:
            loss = self.loss_func(outputs, target)

        if self.current_iteration % self.cfg.SOLVER.TENSORBOARD.LOG_PERIOD == 0:
            if self.summary_writer:
                self.summary_writer.add_scalar('Train/loss', loss,
                                               self.current_iteration)
        loss.backward()
        self.optim.step()

        # A model returning several outputs is scored on their mean.
        if isinstance(outputs, tuple) and len(outputs) > 1:
            _output = outputs[0]
            for output in outputs[1:]:
                _output = _output + output
            outputs = _output / len(outputs)

        target = target.data.cpu()
        outputs = outputs.data.cpu()

        f1, acc = calculate_score(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc)
        self.f1_avg.update(f1)

        return self.loss_avg.avg, self.acc_avg.avg, self.f1_avg.avg
示例#9
0
文件: train.py 项目: szq0214/MEAL-V2
def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args):
    """Train ``model`` for one epoch with a combined ``g_loss`` + discriminator loss.

    Each batch is optionally mixed with CutMix (when ``args.w_cutmix`` is
    true, ``args.beta`` is positive and a uniform draw is below
    ``args.cutmix_prob``). Only the images are mixed: the loss below is
    computed from the model's own ``soft_label`` output on the (possibly
    mixed) images together with the original ``labels``.

    Args:
        model: Called as ``model(images, before=True)``; expected to return a
            3-tuple ``(output, soft_label, soft_no_softmax)``.
        g_loss: Loss module called as ``g_loss((output, soft_label), labels)``;
            may return either a loss or a ``(loss, outputs)`` tuple.
        discriminator_loss: Called as
            ``discriminator_loss([output], [soft_no_softmax])``.
        train_loader: Iterable yielding ``(images, labels)`` pairs.
        optimizer: Optimizer stepped once per batch.
        epoch_number: Epoch number, used only for logging.
        args: Namespace providing ``w_cutmix``, ``beta`` and ``cutmix_prob``.

    Returns:
        None. Progress and the epoch summary are written with ``logging``.
    """
    model.train()
    g_loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    g_loss_meter = utils.AverageMeter(recent=100)
    d_loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        if args.w_cutmix == True:
            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                # Keep the index on the images' device: images only move to
                # CUDA when the model is on CUDA, so forcing .cuda() here
                # would break CPU runs.
                rand_index = torch.randperm(images.size()[0]).to(images.device)
                bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
                images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]

        # Forward pass, backward pass, and update parameters.
        outputs = model(images, before=True)
        output, soft_label, soft_no_softmax = outputs
        g_loss_output = g_loss((output, soft_label), labels)
        d_loss_value = discriminator_loss([output], [soft_no_softmax])

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(g_loss_output, tuple):
            g_loss_value, outputs = g_loss_output
        else:
            g_loss_value = g_loss_output

        loss_value = g_loss_value + d_loss_value

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        g_loss_meter.update(g_loss_value.item(), batch_size)
        d_loss_meter.update(d_loss_value.item(), batch_size)

        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        if i%20 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}}    '
                'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(
                    epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
                    batch_time=batch_time_meter, data_time=data_time_meter,
                    g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter,
                    lr=_get_learning_rate(optimizer)))
    # Log the overall train stats
    logging.info(
        'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Data {data_time.sum:.2f}   '
        'G_Loss {g_loss.average:.3f}     '
        'D_Loss {d_loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '.format(
            epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
            g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter))
示例#10
0
def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader,
                        optimizer, epoch_number, args):
    """Train ``model`` for one epoch against stored soft labels.

    Each loader item is a triple of sequences that are concatenated along
    dim 0 into one batch. When CutMix is drawn for a batch, the images are
    mixed and both losses are computed against the original and the permuted
    soft labels, weighted by ``lam`` and ``1 - lam``. Batches for which
    CutMix is enabled but not drawn use the plain, unmixed loss.

    Args:
        model: Called as ``model(images)`` and expected to return the logits.
        g_loss: Loss module called as ``g_loss((output, soft), labels)``; may
            return either a loss or a ``(loss, output)`` tuple.
        discriminator_loss: Called as ``discriminator_loss([probs], [soft])``;
            only used when ``args.use_discriminator_loss`` is true.
        train_loader: Iterable yielding ``(images, labels, soft_labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch_number: Epoch number, used only for logging.
        args: Namespace providing ``batch_size``, ``soft_label_type``,
            ``num_classes``, ``w_cutmix``, ``beta``, ``cutmix_prob`` and
            ``use_discriminator_loss``.

    Returns:
        None. Progress and the epoch summary are written with ``logging``.
    """
    model.train()
    g_loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    g_loss_meter = utils.AverageMeter(recent=100)
    d_loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels, soft_labels) in enumerate(train_loader):
        batch_size = args.batch_size

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        images = torch.cat(images, dim=0)
        soft_labels = torch.cat(soft_labels, dim=0)
        labels = torch.cat(labels, dim=0)

        if args.soft_label_type != 'ori':
            soft_labels = Recover_soft_label(soft_labels, args.soft_label_type,
                                             args.num_classes)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()
        # Soft labels follow the images instead of being forced onto CUDA.
        soft_labels = soft_labels.to(images.device)

        # Whether CutMix was drawn for THIS batch. Tracking it per batch
        # avoids using undefined (first batch) or stale (later batches)
        # target_a / target_b / lam when the random draw skips mixing.
        mixed = False
        if args.w_cutmix == True:
            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                mixed = True
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                rand_index = torch.randperm(images.size()[0]).to(images.device)
                target_a = soft_labels
                target_b = soft_labels[rand_index]
                bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
                images[:, :, bbx1:bbx2,
                       bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                           (images.size()[-1] * images.size()[-2]))

        # Forward pass, backward pass, and update parameters.
        output = model(images)
        if mixed:
            g_loss_output1 = g_loss((output, target_a), labels)
            g_loss_output2 = g_loss((output, target_b), labels)
        else:
            g_loss_output = g_loss((output, soft_labels), labels)

        d_loss_value = None
        if args.use_discriminator_loss:
            # Our stored label is "after softmax", this is slightly different
            # from original MEAL V2 that used probabilities "before softmax"
            # for the discriminator.
            # dim=1 is what the implicit-dim call selected for 2-D logits;
            # NOTE(review): assumes output is (batch, classes) - confirm.
            output_softmax = nn.functional.softmax(output, dim=1)
            if mixed:
                d_loss_value = discriminator_loss(
                    [output_softmax], [target_a]) * lam + discriminator_loss(
                        [output_softmax], [target_b]) * (1 - lam)
            else:
                d_loss_value = discriminator_loss([output_softmax],
                                                  [soft_labels])

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if mixed:
            if isinstance(g_loss_output1, tuple):
                g_loss_value1, _ = g_loss_output1
                g_loss_value2, _ = g_loss_output2
            else:
                g_loss_value1 = g_loss_output1
                g_loss_value2 = g_loss_output2
            g_loss_value = g_loss_value1 * lam + g_loss_value2 * (1 - lam)
        else:
            if isinstance(g_loss_output, tuple):
                g_loss_value, output = g_loss_output
            else:
                g_loss_value = g_loss_output

        if d_loss_value is None:
            loss_value = g_loss_value
        else:
            loss_value = g_loss_value + d_loss_value

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        g_loss_meter.update(g_loss_value.item(), batch_size)
        # With the discriminator disabled there is no value to read; record
        # its zero contribution so the log lines below still format.
        d_loss_meter.update(
            0.0 if d_loss_value is None else d_loss_value.item(), batch_size)

        top1, top5 = utils.topk_accuracy(output, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        if i % 20 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}}    '
                'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(epoch=epoch_number,
                                     batch=i + 1,
                                     epoch_size=len(train_loader),
                                     batch_time=batch_time_meter,
                                     data_time=data_time_meter,
                                     g_loss=g_loss_meter,
                                     d_loss=d_loss_meter,
                                     top1=top1_meter,
                                     top5=top5_meter,
                                     lr=_get_learning_rate(optimizer)))
    # Log the overall train stats
    logging.info('Epoch: [{epoch}] -- TRAINING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'G_Loss {g_loss.average:.3f}     '
                 'D_Loss {d_loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     g_loss=g_loss_meter,
                     d_loss=d_loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))