def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        
        outputs, features = model(imgs)
        if isinstance(outputs, (tuple, list)):
            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
        else:
            xent_loss = criterion_xent(outputs, pids)
        
        if isinstance(features, (tuple, list)):
            htri_loss = DeepSupervision(criterion_htri, features, pids)
        else:
            htri_loss = criterion_htri(features, pids)
            
        loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        xent_losses.update(xent_loss.item(), pids.size(0))
        htri_losses.update(htri_loss.item(), pids.size(0))
        accs.update(accuracy(outputs, pids)[0])

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                   epoch + 1, batch_idx + 1, len(trainloader),
                   batch_time=batch_time,
                   data_time=data_time,
                   xent=xent_losses,
                   htri=htri_losses,
                   acc=accs
            ))
        
        end = time.time()
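All of these examples lean on the same small helpers: AverageMeter tracks a running average of a metric and accuracy computes top-k precision. A minimal sketch of both, assuming torchreid's conventional behaviour (top-k values returned as percentages), is:

import torch


class AverageMeter(object):
    """Track the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent; returns one value per requested k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1)      # (batch, maxk) predicted class ids
    pred = pred.t()                          # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
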
Example #2
def train(epoch,
          max_epoch,
          model,
          criterion,
          optimizer,
          trainloader,
          fixbase_epoch=0,
          open_layers=None):
    losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    if (epoch + 1) <= fixbase_epoch and open_layers is not None:
        print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch + 1,
                                                      fixbase_epoch))
        open_specified_layers(model, open_layers)
    else:
        open_all_layers(model)
    num_batches = len(trainloader)
    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        imgs = imgs.cuda()
        pids = pids.cuda()
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, pids)
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        accs.update(accuracy(outputs, pids)[0].item())
        if (batch_idx + 1) % 20 == 0:
            eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                            (max_epoch -
                                             (epoch + 1)) * num_batches)
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                  'Lr {lr:.6f}\t'
                  'eta {eta}'.format(epoch + 1,
                                     max_epoch,
                                     batch_idx + 1,
                                     num_batches,
                                     batch_time=batch_time,
                                     data_time=data_time,
                                     loss=losses,
                                     acc=accs,
                                     lr=optimizer.param_groups[0]['lr'],
                                     eta=eta_str))
        end = time.time()
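open_specified_layers and open_all_layers control which parts of the model receive gradients during the fixbase warm-up. A minimal sketch of the freezing logic, assuming layers are addressed by their top-level attribute names:

import torch.nn as nn


def open_all_layers(model):
    """Make every layer trainable again."""
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """Train only the named child modules; freeze (and eval) everything else."""
    if isinstance(model, nn.DataParallel):
        model = model.module
    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False
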
Example #3
def train(epoch,
          model,
          criterion,
          optimizer,
          trainloader,
          use_gpu,
          fixbase=False):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        # NOTE: the DeepSupervision branch is intentionally disabled in this variant,
        # so the criterion is always applied to the raw outputs.
        if False and isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1,
                      batch_idx + 1,
                      len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

        end = time.time()
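DeepSupervision, used when the model returns several intermediate outputs, simply applies the criterion to each output and averages the result; a minimal sketch consistent with how the examples call it:

def DeepSupervision(criterion, xs, y):
    """Apply criterion to each element of the tuple/list xs and average the losses."""
    loss = 0.0
    for x in xs:
        loss += criterion(x, y)
    return loss / len(xs)
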
def train_base(model):

    use_sgd = os.environ.get('sgd') is not None

    optimizer_getter = get_base_sgd_optimizer if use_sgd else get_base_optimizer

    optimizer, scheduler = get_base_optimizer(model)

    model.train()
    print('=== train base ===')

    open_layers = ['fc', 'classifier1', 'classifier2_1', 'classifier2_2', 'fc2_1', 'fc2_2', 'reduction', 'classifier']

    print('Train {} for {} epochs while keeping other layers frozen'.format(open_layers, 10))

    for epoch in range(10):
        open_specified_layers(model, open_layers)
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)

    print('Done. All layers are open to train for {} epochs'.format(60))
    open_all_layers(model)

    optimizer, scheduler = optimizer_getter(model)

    for epoch in range(60):
        train(epoch, model, criterion, optimizer, trainloader, use_gpu=use_gpu)
        scheduler.step()

        if (epoch + 1) % args.eval_freq == 0:
            print('=> Test')
            for name in args.target_names:
                print('Evaluating {} ...'.format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, use_gpu)

    save_checkpoint({
        'state_dict': model.state_dict(),
        'rank1': rank1,
        'epoch': 0,
        'arch': args.arch,
        'optimizer': optimizer.state_dict(),
    }, args.save_dir, prefix='base_')
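get_base_optimizer and get_base_sgd_optimizer are project-specific; a hypothetical sketch of what they might return (an optimizer paired with a step-decay scheduler, with illustrative hyper-parameters only) is:

import torch


def get_base_optimizer(model):
    # hypothetical defaults: Adam plus step decay
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    return optimizer, scheduler


def get_base_sgd_optimizer(model):
    # hypothetical SGD variant, selected when the 'sgd' environment variable is set
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                                weight_decay=5e-4, nesterov=True)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    return optimizer, scheduler
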
Example #5
def train(epoch,
          model,
          criterion,
          regularizer,
          optimizer,
          trainloader,
          use_gpu,
          fixbase=False):
    start_train_time = time.time()
    if not fixbase and args.use_of and epoch >= args.of_start_epoch:
        print('Using OF')

    from torchreid.losses.of_penalty import OFPenalty

    of_penalty = OFPenalty(vars(args))

    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):

        try:
            limited = float(os.environ.get('limited', None))
        except (ValueError, TypeError):
            limited = 1

        if not fixbase and (batch_idx + 1) > limited * len(trainloader):
            break

        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        loss = criterion(outputs, pids)
        if not fixbase:
            reg = regularizer(model)
            loss += reg
        if not fixbase and args.use_of and epoch >= args.of_start_epoch:
            penalty = of_penalty(outputs)
            loss += penalty

        optimizer.zero_grad()

        if use_apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1,
                      batch_idx + 1,
                      len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
        end = time.time()
    epoch_time = time.time() - start_train_time
    print(f"epoch_time:{epoch_time // 60} min {int(epoch_time) % 60}s")
    return losses.avg
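The use_apex flag and amp.scale_loss assume the model and optimizer were wrapped with NVIDIA apex before training starts; a minimal setup sketch (assuming apex is installed, falling back to plain FP32 otherwise) that would be called once before the epoch loop:

def setup_mixed_precision(model, optimizer, opt_level='O1'):
    """Wrap model/optimizer with apex amp; return use_apex=False if apex is missing."""
    try:
        from apex import amp
    except ImportError:
        return model, optimizer, False
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    return model, optimizer, True
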
Example #6
def train(epoch,
          model,
          model_decoder,
          criterion_xent,
          criterion_htri,
          optimizer,
          optimizer_decoder,
          optimizer_encoder,
          trainloader,
          use_gpu,
          fixbase=False):
    losses = AverageMeter()
    losses_recon = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    model_decoder.train()

    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, img_paths,
                    imgs_texture) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()
            imgs_texture = imgs_texture.cuda()

        outputs, features, feat_texture, x_down1, x_down2, x_down3 = model(imgs)
        torch.cuda.empty_cache()

        if args.htri_only:
            if isinstance(features, (tuple, list)):
                loss = DeepSupervision(criterion_htri, features, pids)
            else:
                loss = criterion_htri(features, pids)
        else:
            if isinstance(outputs, (tuple, list)):
                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
            else:
                xent_loss = criterion_xent(outputs, pids)

            if isinstance(features, (tuple, list)):
                htri_loss = DeepSupervision(criterion_htri, features, pids)
            else:
                htri_loss = criterion_htri(features, pids)

            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        del outputs, features

        # Second forward for training texture reconstruction
        close_specified_layers(model, ['fc', 'classifier'])

        recon_texture, x_sim1, x_sim2, x_sim3, x_sim4 = model_decoder(
            feat_texture, x_down1, x_down2, x_down3)
        torch.cuda.empty_cache()

        loss_rec = nn.L1Loss()
        loss_tri = nn.MSELoss()
        loss_recon = loss_rec(recon_texture, imgs_texture)  #*0.1

        # L1 loss to push same id's feat more similar:
        loss_triplet_id_sim1 = 0.0
        loss_triplet_id_sim2 = 0.0
        loss_triplet_id_sim3 = 0.0
        loss_triplet_id_sim4 = 0.0

        for i in range(0, ((args.train_batch_size // args.num_instances) - 1) *
                       args.num_instances, args.num_instances):
            loss_triplet_id_sim1 += max(
                loss_tri(x_sim1[i], x_sim1[i + 1]) -
                loss_tri(x_sim1[i], x_sim1[i + 4]) + 0.3, 0.0)
            loss_triplet_id_sim2 += max(
                loss_tri(x_sim2[i + 1], x_sim2[i + 2]) -
                loss_tri(x_sim2[i + 1], x_sim2[i + 5]) + 0.3,
                0.0)  #loss_tri(x_sim2[i+1], x_sim2[i+2])
            loss_triplet_id_sim3 += max(
                loss_tri(x_sim3[i + 2], x_sim3[i + 3]) -
                loss_tri(x_sim3[i + 2], x_sim3[i + 6]) + 0.3,
                0.0)  #loss_tri(x_sim3[i+2], x_sim3[i+3])
            loss_triplet_id_sim4 += max(
                loss_tri(x_sim4[i], x_sim4[i + 3]) -
                loss_tri(x_sim4[i + 3], x_sim4[i + 4]) + 0.3,
                0.0)  #loss_tri(x_sim4[i], x_sim4[i+3])
        loss_same_id = loss_triplet_id_sim1 + loss_triplet_id_sim2 + loss_triplet_id_sim3 + loss_triplet_id_sim4

        loss_recon += (loss_same_id)  # * 0.0001)

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss_recon.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        del feat_texture, x_down1, x_down2, x_down3, recon_texture

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))
        losses_recon.update(loss_recon.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Loss_recon {loss_recon.val:.4f} ({loss_recon.avg:.4f})\t'.
                  format(epoch + 1,
                         batch_idx + 1,
                         len(trainloader),
                         batch_time=batch_time,
                         data_time=data_time,
                         loss=losses,
                         loss_recon=losses_recon))

        end = time.time()
        open_all_layers(model)

        if (epoch + 1) % 50 == 0:
            print("==> Test reconstruction effect")
            model.eval()
            model_decoder.eval()
            features, feat_texture = model(imgs)
            recon_texture = model_decoder(feat_texture)
            out = recon_texture.data.cpu().numpy()[0].squeeze()
            out = out.transpose((1, 2, 0))
            out = (out / 2.0 + 0.5) * 255.
            out = out.astype(np.uint8)
            print(
                'finish: ',
                os.path.join(
                    args.save_dir, img_paths[0].split('bounding_box_train/')
                    [-1].split('.jpg')[0] + 'ep_' + str(epoch) + '.jpg'))
            cv2.imwrite(
                os.path.join(
                    args.save_dir, img_paths[0].split('bounding_box_train/')
                    [-1].split('.jpg')[0] + 'ep_' + str(epoch) + '.jpg'),
                out[:, :, ::-1])
            model.train()
            model_decoder.train()
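The per-identity similarity terms above are hinge (triplet-margin) penalties on MSE distances between decoder feature maps; the same pattern in a small hypothetical helper, using the 0.3 margin from the example:

import torch
import torch.nn.functional as F


def feature_triplet_term(anchor, positive, negative, margin=0.3):
    """max(d(anchor, positive) - d(anchor, negative) + margin, 0), with MSE as distance."""
    d_ap = F.mse_loss(anchor, positive)
    d_an = F.mse_loss(anchor, negative)
    return torch.clamp(d_ap - d_an + margin, min=0.0)
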
Example #7
def train(epoch,
          model,
          criterions,
          optimizer,
          trainloader,
          use_gpu,
          train_writer,
          fixbase=False,
          lfw=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        iteration = epoch * len(trainloader) + batch_idx

        data_time.update(time.time() - end)

        if fixbase and batch_idx > 100:
            break

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs, features = model(imgs)

        losses = torch.zeros([1]).cuda()
        kwargs = {'targets': pids, 'imgs': imgs}
        for criterion in criterions:
            inputs = features
            # 'xent' and 'am' losses operate on the logits; the rest use features
            if criterion.name in ('xent', 'am'):
                inputs = outputs
            loss = criterion.weight * criterion.calc_loss(inputs, **kwargs)
            losses += loss
            if np.isnan(loss.item()):
                logged_value = sys.float_info.max
            else:
                logged_value = loss.item()
            criterion.train_stats.update(logged_value, pids.size(0))

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        if (batch_idx + 1) % args.print_freq == 0:
            output_string = 'Epoch: [{0}][{1}/{2}]\t'.format(
                epoch + 1, batch_idx + 1, len(trainloader))
            output_string += 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                batch_time=batch_time)
            output_string += 'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'.format(
                data_time=data_time)
            for criterion in criterions:
                output_string += 'Loss {}: {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                    criterion.name, loss=criterion.train_stats)
                train_writer.add_scalar('loss/{}'.format(criterion.name),
                                        criterion.train_stats.val, iteration)
            print(output_string)
        end = time.time()
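This variant assumes each criterion is wrapped in an object exposing name, weight, calc_loss and its own AverageMeter; a hypothetical sketch of such a wrapper (the class name and constructor are illustrative only, and it reuses the AverageMeter sketched earlier):

class WeightedCriterion(object):
    """Hypothetical wrapper matching the interface used above."""

    def __init__(self, name, weight, loss_fn):
        self.name = name
        self.weight = weight
        self.loss_fn = loss_fn
        self.train_stats = AverageMeter()  # assumes the AverageMeter sketched earlier

    def calc_loss(self, inputs, targets, **kwargs):
        # extra kwargs (e.g. imgs) are accepted but ignored by simple losses
        return self.loss_fn(inputs, targets)
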
Example #8
def train(epoch,
          model,
          criterion,
          center_loss1,
          center_loss2,
          center_loss3,
          center_loss4,
          optimizer,
          trainloader,
          use_gpu,
          fixbase=False):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _, dataset_id) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids, dataset_id = imgs.cuda(), pids.cuda(), dataset_id.cuda()

        outputs, features = model(imgs)
        if isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)

        alpha = 0.001
        loss = center_loss1(features[0], dataset_id) * alpha + loss
        loss = center_loss2(features[1], dataset_id) * alpha + loss

        # beta = 0.0001
        beta = 0.00001
        loss = center_loss3(features[0], pids) * beta + loss
        loss = center_loss4(features[1], pids) * beta + loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1,
                      batch_idx + 1,
                      len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

        end = time.time()
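center_loss1..center_loss4 are instances of a center loss that pulls each feature toward a learnable center for its label (a dataset id or a person id here); a minimal sketch of the standard formulation:

import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Mean squared distance between features and the learnable centers of their labels."""

    def __init__(self, num_classes, feat_dim):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        centers_batch = self.centers[labels]          # (batch, feat_dim)
        return ((features - centers_batch) ** 2).sum(dim=1).mean() / 2.0
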
Example #9
def train(epoch, model, criterion, optimizer, trainloader, writer, use_gpu, fixbase=False):
    losses = AverageMeter()
    precisions = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_iterations = len(trainloader)

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, ((img1, img2), pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            img1, img2, pids = img1.cuda(), img2.cuda(), pids.cuda()

        y_large, y_small, y_joint = model(img1, img2)

        loss_batch = args.train_loss_batch_size
        how_many_mini = args.train_batch_size // loss_batch
        for mini_idx in range(how_many_mini):

            start_index = mini_idx * loss_batch
            end_index = start_index + loss_batch

            mini_y_large = y_large[start_index:end_index, :]
            mini_y_small = y_small[start_index:end_index, :]
            mini_y_joint = y_joint[start_index:end_index, :]
            mini_pids = pids[start_index:end_index]

            loss_large = criterion(mini_y_large, mini_pids)
            loss_small = criterion(mini_y_small, mini_pids)
            loss_joint = criterion(mini_y_joint, mini_pids)

            joint_prob = F.softmax(mini_y_joint, dim=1)
            loss_joint_large = criterion(mini_y_large, joint_prob, one_hot=True)
            loss_joint_small = criterion(mini_y_small, joint_prob, one_hot=True)

            total_loss_large = loss_large + loss_joint_large
            total_loss_small = loss_small + loss_joint_small
            total_loss_joint = loss_joint

            prec, = accuracy(mini_y_joint.data, mini_pids.data)
            prec1 = prec[0]  # get top 1

            optimizer.zero_grad()

            # sum the three losses and backpropagate once; retain_graph is needed
            # because later mini-batches of this loop reuse the same forward graph
            loss = total_loss_joint + total_loss_small + total_loss_large
            loss.backward(retain_graph=True)

            optimizer.step()

            loss_iter = epoch * epoch_iterations + batch_idx * how_many_mini + mini_idx
            writer.add_scalar('iter/loss_small', loss_small, loss_iter)
            writer.add_scalar('iter/loss_large', loss_large, loss_iter)
            writer.add_scalar('iter/loss_joint', loss_joint, loss_iter)
            writer.add_scalar('iter/loss_joint_small', loss_joint_small, loss_iter)
            writer.add_scalar('iter/loss_joint_large', loss_joint_large, loss_iter)
            writer.add_scalar('iter/total_loss_small', total_loss_small, loss_iter)
            writer.add_scalar('iter/total_loss_large', total_loss_large, loss_iter)
            writer.add_scalar('iter/total_loss_joint', total_loss_joint, loss_iter)
            writer.add_scalar('iter/loss', loss, loss_iter)


            losses.update(loss.item(), pids.size(0))
            precisions.update(prec1, pids.size(0))

            if (batch_idx*how_many_mini+mini_idx + 1) % args.print_freq == 0:
                print('Epoch: [{0:02d}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec {prec.val:.2%} ({prec.avg:.2%})\t'.format(
                       epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                       data_time=data_time, loss=losses, prec=precisions))

        batch_time.update(time.time() - end)
        end = time.time()

    return losses.avg, precisions.avg
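The criterion here accepts either hard labels or, with one_hot=True, the soft distribution produced by the joint branch; a hypothetical sketch of a cross-entropy that supports both call styles:

import torch.nn.functional as F


def soft_or_hard_xent(logits, targets, one_hot=False):
    """Cross entropy against hard labels, or against a soft target distribution."""
    log_probs = F.log_softmax(logits, dim=1)
    if one_hot:
        # targets is already a probability distribution over classes
        return (-targets * log_probs).sum(dim=1).mean()
    return F.nll_loss(log_probs, targets)
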
Example #10
def train(epoch,
          model,
          criterion,
          regularizer,
          optimizer,
          trainloader,
          use_gpu,
          fixbase=False,
          switch_loss=False):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):

        # optional env var to train on only a fraction of the loader per epoch
        try:
            limited = float(os.environ.get('limited', None))
        except (ValueError, TypeError):
            limited = 1

        if not fixbase and (batch_idx + 1) > limited * len(trainloader):
            break

        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        if False and isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        if not fixbase:
            reg = regularizer(model)
            loss += reg
        optimizer.zero_grad()
        loss.backward()

        if args.use_clip_grad and (args.switch_loss < 0 and switch_loss):
            print('Clip!')
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))

        del loss
        del outputs

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1,
                      batch_idx + 1,
                      len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

        end = time.time()
def train(epoch,
          model,
          criterion,
          optimizer,
          trainloader,
          writer,
          use_gpu,
          fixbase=False):
    losses = AverageMeter()
    precisions = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_iterations = len(trainloader)

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, ((img1, img2), pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            img1, img2, pids = img1.cuda(), img2.cuda(), pids.cuda()

        y10, y05, y_consensus = model(img1, img2)

        loss10 = criterion(y10, pids)
        loss05 = criterion(y05, pids)
        loss_consensus = criterion(y_consensus, pids)

        prec, = accuracy(y_consensus.data, pids.data)
        prec1 = prec[0]  # get top 1

        writer.add_scalar('iter/loss', loss_consensus,
                          epoch * epoch_iterations + batch_idx)
        writer.add_scalar('iter/prec1', prec1,
                          epoch * epoch_iterations + batch_idx)

        optimizer.zero_grad()

        loss10.backward(retain_graph=True)
        loss05.backward(retain_graph=True)
        loss_consensus.backward()

        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss_consensus.item(), pids.size(0))
        precisions.update(prec1, pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0:02d}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {prec.val:.2%} ({prec.avg:.2%})\t'.format(
                      epoch + 1,
                      batch_idx + 1,
                      len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      prec=precisions))

        end = time.time()

    return losses.avg, precisions.avg
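All of these train() variants are driven by a similar outer loop; a minimal driver sketch using the Example #3 signature, assuming args provides fixbase_epoch and max_epoch and that the model, criterion, optimizer, scheduler and trainloader are built elsewhere:

def run_training(args, model, criterion, optimizer, scheduler, trainloader, use_gpu):
    # optional warm-up: only the open layers are trained while the backbone stays frozen
    for epoch in range(args.fixbase_epoch):
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)

    # main phase: all layers open, learning rate stepped once per epoch
    for epoch in range(args.max_epoch):
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        scheduler.step()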