def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
    """Run one epoch of classification-loss training.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: network returning logits, or a tuple/list of logits when
            deeply supervised.
        criterion: classification loss, ``criterion(outputs, pids)``.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        fixbase: when True (or ``args.always_fixbase``) only the layers in
            ``args.open_layers`` are trained; otherwise all layers train.
    """
    losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    # Freeze the backbone during warm-up epochs; otherwise train everything.
    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)  # time spent loading the batch

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        # Multiple heads -> average the loss over them via DeepSupervision.
        if isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))
        # NOTE(review): assumes `accuracy` accepts the raw `outputs`
        # (including the tuple/list case) — confirm against its signature.
        accs.update(accuracy(outputs, pids)[0])

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accs))

        end = time.time()
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
    """Run one epoch of classification-loss training, optionally with
    BatchNorm layers held in eval mode.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: network returning logits (tuple when deeply supervised).
        criterion: classification loss, ``criterion(outputs, pids)``.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        freeze_bn: when True (or ``args.freeze_bn``) BN layers are switched
            to eval so running statistics are not updated.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    # Keep BN running stats frozen while the rest of the net trains.
    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)  # data-loading time
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        outputs = model(imgs)
        # Tuple of heads -> deep supervision; single head -> plain loss.
        if isinstance(outputs, tuple):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
        end = time.time()
def compute_loss(self, criterion, outputs, targets):
    """Apply *criterion* to *outputs* against *targets*.

    Dispatches to DeepSupervision when the model produced several output
    heads (a tuple or list); otherwise applies the criterion directly.
    """
    multi_head = isinstance(outputs, (tuple, list))
    if multi_head:
        return DeepSupervision(criterion, outputs, targets)
    return criterion(outputs, targets)
def _compute_loss(self, criterion, outputs, targets, part_weights=None): if isinstance(outputs, (tuple, list)): loss = DeepSupervision(criterion, outputs, targets, part_weights=part_weights) else: loss = criterion(outputs, targets) return loss
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
    """Run one epoch of combined cross-entropy + triplet training.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: returns ``(outputs, features)`` — logits for the xent loss
            and embeddings for the triplet loss.
        criterion_xent: cross-entropy loss on logits.
        criterion_htri: hard-triplet loss on features.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        outputs, features = model(imgs)
        if args.htri_only:
            # Train on the triplet loss alone.
            if isinstance(features, tuple):
                loss = DeepSupervision(criterion_htri, features, pids)
            else:
                loss = criterion_htri(features, pids)
        else:
            # Weighted sum of cross-entropy and triplet losses.
            if isinstance(outputs, tuple):
                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
            else:
                xent_loss = criterion_xent(outputs, pids)
            if isinstance(features, tuple):
                htri_loss = DeepSupervision(criterion_htri, features, pids)
            else:
                htri_loss = criterion_htri(features, pids)
            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
        end = time.time()
def train(epoch, model, keyptaware, multitask, criterion_xent_vid, criterion_xent_vcolor, criterion_xent_vtype, criterion_htri, optimizer, trainloader, use_gpu):
    """Run one vehicle re-ID epoch; optionally keypoint-aware and/or
    multi-task (vehicle color + type classification).

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: forward signature depends on the two flags; returns vehicle-ID
            logits and features, plus color/type logits when multitask.
        keyptaware: feed keypoints ``vkeypts`` to the model when True.
        multitask: add color/type cross-entropy terms when True.
        criterion_xent_vid / criterion_xent_vcolor / criterion_xent_vtype:
            cross-entropy losses for vehicle id / color / type heads.
        criterion_htri: hard-triplet loss on features.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, vids, camids, vcolors, vtypes, vkeypts)``.
        use_gpu: move the tensors actually used to CUDA when True.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    end = time.time()
    for batch_idx, (imgs, vids, camids, vcolors, vtypes, vkeypts) in enumerate(trainloader):
        data_time.update(time.time() - end)
        # Move only the tensors this configuration consumes.
        if use_gpu:
            if keyptaware and multitask:
                imgs, vids, vcolors, vtypes, vkeypts = imgs.cuda(), vids.cuda(), vcolors.cuda(), vtypes.cuda(), vkeypts.cuda()
            elif keyptaware:
                imgs, vids, vkeypts = imgs.cuda(), vids.cuda(), vkeypts.cuda()
            elif multitask:
                imgs, vids, vcolors, vtypes = imgs.cuda(), vids.cuda(), vcolors.cuda(), vtypes.cuda()
            else:
                imgs, vids = imgs.cuda(), vids.cuda()
        # Forward pass; the returned heads mirror the flag combination.
        if keyptaware and multitask:
            output_vids, output_vcolors, output_vtypes, features = model(imgs, vkeypts)
        elif keyptaware:
            output_vids, features = model(imgs, vkeypts)
        elif multitask:
            output_vids, output_vcolors, output_vtypes, features = model(imgs)
        else:
            output_vids, features = model(imgs)
        if args.htri_only:
            # Triplet loss alone.
            if isinstance(features, tuple):
                loss = DeepSupervision(criterion_htri, features, vids)
            else:
                loss = criterion_htri(features, vids)
        else:
            if isinstance(output_vids, tuple):
                xent_loss = DeepSupervision(criterion_xent_vid, output_vids, vids)
            else:
                xent_loss = criterion_xent_vid(output_vids, vids)
            if isinstance(features, tuple):
                htri_loss = DeepSupervision(criterion_htri, features, vids)
            else:
                htri_loss = criterion_htri(features, vids)
            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
            # Auxiliary color/type classification terms.
            if multitask:
                if isinstance(output_vcolors, tuple):
                    xent_loss_vcolor = DeepSupervision(criterion_xent_vcolor, output_vcolors, vcolors)
                else:
                    xent_loss_vcolor = criterion_xent_vcolor(output_vcolors, vcolors)
                if isinstance(output_vtypes, tuple):
                    xent_loss_vtype = DeepSupervision(criterion_xent_vtype, output_vtypes, vtypes)
                else:
                    xent_loss_vtype = criterion_xent_vtype(output_vtypes, vtypes)
                loss += args.lambda_vcolor * xent_loss_vcolor + args.lambda_vtype * xent_loss_vtype
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), vids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
        end = time.time()
def train(epoch, model, model_decoder, criterion_xent, criterion_htri, optimizer, optimizer_decoder, optimizer_encoder, trainloader, use_gpu, fixbase=False):
    """Run one epoch of joint re-ID + texture-reconstruction training.

    Per batch there are two backward passes:
      1. re-ID loss (xent + triplet) through `optimizer`;
      2. reconstruction L1 loss plus a triplet-style feature-similarity
         term through `optimizer_encoder`/`optimizer_decoder`, with the
         'fc'/'classifier' layers temporarily closed.

    NOTE(review): original indentation was lost; `open_all_layers(model)` is
    placed at the end of each batch iteration (to re-open the layers closed
    for the reconstruction step) and the every-50-epochs visualization block
    after the loop — confirm against the original source.
    """
    losses = AverageMeter()
    losses_recon = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    model_decoder.train()
    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)
    end = time.time()
    for batch_idx, (imgs, pids, _, img_paths, imgs_texture) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids, imgs_texture = imgs.cuda(), pids.cuda(), imgs_texture.cuda()
        # First forward: re-ID heads plus intermediate feature maps that the
        # decoder consumes below.
        outputs, features, feat_texture, x_down1, x_down2, x_down3 = model(imgs)
        torch.cuda.empty_cache()
        if args.htri_only:
            if isinstance(features, (tuple, list)):
                loss = DeepSupervision(criterion_htri, features, pids)
            else:
                loss = criterion_htri(features, pids)
        else:
            if isinstance(outputs, (tuple, list)):
                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
            else:
                xent_loss = criterion_xent(outputs, pids)
            if isinstance(features, (tuple, list)):
                htri_loss = DeepSupervision(criterion_htri, features, pids)
            else:
                htri_loss = criterion_htri(features, pids)
            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        # retain_graph: the reconstruction backward below reuses this graph.
        loss.backward(retain_graph=True)
        optimizer.step()
        del outputs, features
        # Second forward for training texture reconstruction
        close_specified_layers(model, ['fc', 'classifier'])
        recon_texture, x_sim1, x_sim2, x_sim3, x_sim4 = model_decoder(feat_texture, x_down1, x_down2, x_down3)
        torch.cuda.empty_cache()
        loss_rec = nn.L1Loss()
        loss_tri = nn.MSELoss()
        loss_recon = loss_rec(recon_texture, imgs_texture)  #*0.1
        # L1 loss to push same id's feat more similar:
        loss_triplet_id_sim1 = 0.0
        loss_triplet_id_sim2 = 0.0
        loss_triplet_id_sim3 = 0.0
        loss_triplet_id_sim4 = 0.0
        # Batches are grouped by identity (num_instances per id); each term
        # is a margin (0.3) triplet over MSE distances at one decoder scale.
        for i in range(0, ((args.train_batch_size // args.num_instances) - 1) * args.num_instances, args.num_instances):
            loss_triplet_id_sim1 += max(loss_tri(x_sim1[i], x_sim1[i + 1]) - loss_tri(x_sim1[i], x_sim1[i + 4]) + 0.3, 0.0)
            loss_triplet_id_sim2 += max(loss_tri(x_sim2[i + 1], x_sim2[i + 2]) - loss_tri(x_sim2[i + 1], x_sim2[i + 5]) + 0.3, 0.0)  #loss_tri(x_sim2[i+1], x_sim2[i+2])
            loss_triplet_id_sim3 += max(loss_tri(x_sim3[i + 2], x_sim3[i + 3]) - loss_tri(x_sim3[i + 2], x_sim3[i + 6]) + 0.3, 0.0)  #loss_tri(x_sim3[i+2], x_sim3[i+3])
            loss_triplet_id_sim4 += max(loss_tri(x_sim4[i], x_sim4[i + 3]) - loss_tri(x_sim4[i + 3], x_sim4[i + 4]) + 0.3, 0.0)  #loss_tri(x_sim4[i], x_sim4[i+3])
        loss_same_id = loss_triplet_id_sim1 + loss_triplet_id_sim2 + loss_triplet_id_sim3 + loss_triplet_id_sim4
        loss_recon += (loss_same_id)  # * 0.0001)
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss_recon.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()
        del feat_texture, x_down1, x_down2, x_down3, recon_texture
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        losses_recon.update(loss_recon.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Loss_recon {loss_recon.val:.4f} ({loss_recon.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      loss_recon=losses_recon))
        end = time.time()
        # Re-open the layers closed for the reconstruction step.
        open_all_layers(model)
    # Every 50 epochs: dump a reconstruction of the last batch's first image.
    if (epoch + 1) % 50 == 0:
        print("==> Test reconstruction effect")
        model.eval()
        model_decoder.eval()
        # NOTE(review): in eval mode the model apparently returns only
        # (features, feat_texture) — confirm against the model definition.
        features, feat_texture = model(imgs)
        recon_texture = model_decoder(feat_texture)
        out = recon_texture.data.cpu().numpy()[0].squeeze()
        out = out.transpose((1, 2, 0))
        # Undo the [-1, 1] normalization back to 0-255 pixel values.
        out = (out / 2.0 + 0.5) * 255.
        out = out.astype(np.uint8)
        print('finish: ', os.path.join(args.save_dir, img_paths[0].split('bounding_box_train/')[-1].split('.jpg')[0] + 'ep_' + str(epoch) + '.jpg'))
        # Channel flip: numpy RGB -> OpenCV BGR.
        cv2.imwrite(os.path.join(args.save_dir, img_paths[0].split('bounding_box_train/')[-1].split('.jpg')[0] + 'ep_' + str(epoch) + '.jpg'), out[:, :, ::-1])
        model.train()
        model_decoder.train()
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, writer, args, freeze_bn=False):
    """Run one epoch of xent(+optional triplet) training with TensorBoard
    logging and in-place console line rewriting.

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        model: returns ``(outputs, features)``.
        criterion_xent: cross-entropy loss; criterion_htri: triplet loss or
            None to disable the triplet term.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        writer: TensorBoard-style writer, receives epoch-average losses.
        args: parsed CLI options (htri_only, lambdas, freeze_bn, print_freq).
        freeze_bn: hold BatchNorm layers in eval mode when True.
    """
    losses = AverageMeter()
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    printed = False
    model.train()
    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        outputs, features = model(imgs)
        if args.htri_only:
            if isinstance(features, tuple):
                htri_loss = DeepSupervision(criterion_htri, features, pids)
            else:
                htri_loss = criterion_htri(features, pids)
            loss = args.lambda_htri * htri_loss
        else:
            if isinstance(outputs, tuple):
                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
            else:
                xent_loss = criterion_xent(outputs, pids)
            # Triplet term is optional; stays a plain 0 when disabled.
            htri_loss = 0
            if not (criterion_htri is None):
                if isinstance(features, tuple):
                    htri_loss = DeepSupervision(criterion_htri, features, pids)
                else:
                    htri_loss = criterion_htri(features, pids)
            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        if not args.htri_only:
            xent_losses.update(xent_loss.item(), pids.size(0))
        # htri_loss is an int 0 (no .item()) when criterion_htri is None.
        if criterion_htri is None:
            htri_losses.update(htri_loss, pids.size(0))
        else:
            htri_losses.update(htri_loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            if not printed:
                printed = True
            else:
                # Clean the current line
                # NOTE(review): sys.stdout.console suggests a custom Logger
                # wrapping stdout — confirm elsewhere in the project.
                sys.stdout.console.write("\033[F\033[K")
                #sys.stdout.console.write("\033[K")
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Xent Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                  'Htri Loss {htri_loss.val:.4f} ({htri_loss.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      xent_loss=xent_losses,
                      htri_loss=htri_losses,
                      loss=losses))
        end = time.time()
    # Epoch-level TensorBoard scalars.
    writer.add_scalars(
        'Losses',
        dict(total_loss=losses.avg,
             xen_loss=xent_losses.avg,
             htri_loss=htri_losses.avg), epoch + 1)
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer, args, freeze_bn=False):
    """Run one epoch of xent training with a confidence-penalty or JSD
    regularization term.

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        model: returns logits (tuple when deeply supervised).
        criterion: pair-like; ``criterion[0]`` is cross-entropy,
            ``criterion[1]`` the confidence/JSD term.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        writer: TensorBoard-style writer for epoch-average losses.
        args: CLI options (confidence_penalty, jsd, betas, print_freq, ...).
        freeze_bn: hold BatchNorm layers in eval mode when True.
    """
    losses = AverageMeter()
    xent_losses = AverageMeter()
    confidence_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    printed = False
    model.train()
    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        outputs = model(imgs)
        if isinstance(outputs, tuple):
            xent_loss = DeepSupervision(criterion[0], outputs, pids)
            confidence_loss = DeepSupervision(criterion[1], outputs, pids)
        else:
            xent_loss = criterion[0](outputs, pids)
            confidence_loss = criterion[1](outputs, pids)
        # Penalty is subtracted, JSD is added; otherwise xent only.
        if args.confidence_penalty:
            loss = args.lambda_xent * xent_loss - args.confidence_beta * confidence_loss
        elif args.jsd:
            loss = args.lambda_xent * xent_loss + args.confidence_beta * confidence_loss
        else:
            loss = args.lambda_xent * xent_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        xent_losses.update(xent_loss.item(), pids.size(0))
        confidence_losses.update(confidence_loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            if not printed:
                printed = True
            else:
                # Clean the current line
                sys.stdout.console.write("\033[F\033[K")
                #sys.stdout.console.write("\033[K")
            if args.jsd:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                    'Xent_Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                    'JSD_Loss {confidence_loss.val:.4f} ({confidence_loss.avg:.4f})\t'
                    'Total_Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        epoch + 1, batch_idx + 1, len(trainloader),
                        batch_time=batch_time,
                        data_time=data_time,
                        xent_loss=xent_losses,
                        confidence_loss=confidence_losses,
                        loss=losses))
            else:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                    'Xent_Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                    'Confi_Loss {confidence_loss.val:.4f} ({confidence_loss.avg:.4f})\t'
                    'Total_Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        epoch + 1, batch_idx + 1, len(trainloader),
                        batch_time=batch_time,
                        data_time=data_time,
                        xent_loss=xent_losses,
                        confidence_loss=confidence_losses,
                        loss=losses))
        end = time.time()
    # Epoch-level TensorBoard scalars.
    writer.add_scalars(
        'loss',
        dict(loss=losses.avg,
             xent_loss=xent_losses.avg,
             confidence_loss=confidence_losses.avg), epoch + 1)
def train(epoch, model, criterion, center_loss1, center_loss2, center_loss3, center_loss4, optimizer, trainloader, use_gpu, fixbase=False):
    """Run one epoch of classification training with four auxiliary center
    losses: two keyed on dataset id, two keyed on person id, each applied to
    one of two feature branches.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: returns ``(outputs, features)``; ``features`` is indexable
            (features[0], features[1] are the two branches).
        criterion: classification loss on outputs.
        center_loss1 / center_loss2: center losses over dataset_id.
        center_loss3 / center_loss4: center losses over pids.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _, _, dataset_id)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        fixbase: when True (or ``args.always_fixbase``) only
            ``args.open_layers`` are trained.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)
    end = time.time()
    for batch_idx, (imgs, pids, _, _, dataset_id) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids, dataset_id = imgs.cuda(), pids.cuda(), dataset_id.cuda()
        outputs, features = model(imgs)
        if isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        # Dataset-id center losses, weight alpha.
        alpha = 0.001
        loss = center_loss1(features[0], dataset_id) * alpha + loss
        loss = center_loss2(features[1], dataset_id) * alpha + loss
        # Person-id center losses, weight belta.
        # belta = 0.0001
        belta = 0.00001
        loss = center_loss3(features[0], pids) * belta + loss
        loss = center_loss4(features[1], pids) * belta + loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
        end = time.time()
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer, args, freeze_bn=False):
    """Run one epoch of information-bottleneck training: cross-entropy plus
    a KL term on the model's Gaussian latent (mu, std).

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        model: returns ``((mu, std), outputs)``.
        criterion: cross-entropy loss on outputs.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        writer: TensorBoard-style writer for epoch-average losses.
        args: CLI options (lambda_xent, beta, freeze_bn, print_freq).
        freeze_bn: hold BatchNorm layers in eval mode when True.
    """
    losses = AverageMeter()
    xent_losses = AverageMeter()
    info_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    printed = False
    model.train()
    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        (mu, std), outputs = model(imgs)
        if isinstance(outputs, tuple):
            xent_loss = DeepSupervision(criterion, outputs, pids)
        else:
            xent_loss = criterion(outputs, pids)
        # KL(N(mu, std) || N(0, 1)) per sample, converted to bits via
        # division by log(2).
        info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(1).mean().div(math.log(2))
        loss = args.lambda_xent * xent_loss + args.beta * info_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        xent_losses.update(xent_loss.item(), pids.size(0))
        info_losses.update(info_loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            if not printed:
                printed = True
            else:
                # Clean the current line
                sys.stdout.console.write("\033[F\033[K")
                #sys.stdout.console.write("\033[K")
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Xent_Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                  'Info_Loss {info_loss.val:.4f} ({info_loss.avg:.4f})\t'
                  'Total_Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      xent_loss=xent_losses,
                      info_loss=info_losses,
                      loss=losses))
        end = time.time()
    # Epoch-level TensorBoard scalars.
    writer.add_scalars(
        'loss',
        dict(loss=losses.avg,
             xent_loss=xent_losses.avg,
             info_loss=info_losses.avg), epoch + 1)
def train(epoch, model, criterion, regularizer, optimizer, trainloader, use_gpu, fixbase=False, switch_loss=False):
    """Run one epoch of classification training with a model regularizer
    and optional gradient clipping.

    Fixes vs. original: removed the per-batch debug ``print(loss)`` (spams
    stdout every iteration) and the dead ``if False and isinstance(...)``
    branch, whose DeepSupervision path could never execute.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: network returning logits.
        criterion: classification loss, ``criterion(outputs, pids)``.
        regularizer: callable returning a scalar penalty for the model;
            added to the loss except during fixbase warm-up.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        fixbase: when True (or ``args.fixbase``) only ``args.open_layers``
            are trained and the regularizer is skipped.
        switch_loss: together with ``args.switch_loss < 0`` and
            ``args.use_clip_grad``, enables gradient clipping.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    if fixbase or args.fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        # The 'limited' env var caps the fraction of batches per epoch
        # (unset/invalid -> full epoch); never applied during fixbase.
        try:
            limited = float(os.environ.get('limited', None))
        except (ValueError, TypeError):
            limited = 1
        if not fixbase and (batch_idx + 1) > limited * len(trainloader):
            break

        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        # NOTE: the DeepSupervision path was deliberately disabled upstream
        # ('if False and isinstance(...)'); the criterion applies directly.
        loss = criterion(outputs, pids)

        if not fixbase:
            reg = regularizer(model)
            loss += reg

        optimizer.zero_grad()
        loss.backward()
        if args.use_clip_grad and (args.switch_loss < 0 and switch_loss):
            print('Clip!')
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
        optimizer.step()

        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))

        # Free graph references promptly to keep GPU memory low.
        del loss
        del outputs

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

        end = time.time()
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer, args, freeze_bn=False):
    """Run one epoch of information-bottleneck training with optional
    confidence-penalty/JSD terms, supporting either a criterion tuple or a
    MultiHeadLossAutoTune instance.

    Fix vs. original: the ``args.jsd`` progress print referenced
    ``{info_loss...}`` in its format string but omitted
    ``info_loss=info_losses`` from ``.format()``, raising ``KeyError`` the
    first time that branch printed. The kwarg is now supplied.

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        model: returns ``((mu, std), outputs)``.
        criterion: either an indexable ``(xent, confidence, ..., info)``
            collection, or a MultiHeadLossAutoTune combining the terms.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        writer: TensorBoard-style writer for epoch-average losses.
        args: CLI options (confidence_penalty, jsd, betas, print_freq, ...).
        freeze_bn: hold BatchNorm layers in eval mode when True.
    """
    losses = AverageMeter()
    xent_losses = AverageMeter()
    info_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    confidence_losses = AverageMeter()
    printed = False
    model.train()
    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        (mu, std), outputs = model(imgs)
        text_dict = {}
        if not isinstance(criterion, MultiHeadLossAutoTune):
            # Manual weighting: criterion[0]=xent, criterion[1]=confidence,
            # criterion[-1]=info term over (mu, std).
            if isinstance(outputs, tuple):
                xent_loss = DeepSupervision(criterion[0], outputs, pids)
                confidence_loss = DeepSupervision(criterion[1], outputs, pids)
            else:
                xent_loss = criterion[0](outputs, pids)
                confidence_loss = criterion[1](outputs, pids)
            info_loss = criterion[-1](mu.float(), std.float())
            if args.confidence_penalty:
                loss = args.lambda_xent * xent_loss + args.beta * info_loss - args.confidence_beta * confidence_loss
            elif args.jsd:
                loss = args.lambda_xent * xent_loss + args.beta * info_loss + args.confidence_beta * confidence_loss
            else:
                loss = args.lambda_xent * xent_loss + args.beta * info_loss
            confidence_losses.update(confidence_loss.item(), pids.size(0))
        else:
            # Auto-tuned weighting: the criterion balances the heads itself
            # and reports the individual terms.
            if args.confidence_penalty or args.jsd:
                loss, individual_losses = criterion([outputs, outputs, mu], [pids, pids, std])
                confidence_loss = individual_losses[1]
            else:
                loss, individual_losses = criterion([outputs, mu], [pids, std])
                confidence_loss = 0
            xent_loss = individual_losses[0]
            info_loss = individual_losses[-1]
            text_dict = criterion.batch_meta()
            confidence_losses.update(0, pids.size(0))
        #info_loss = -0.5*(1-2*std.log()-(1+mu.pow(2))/(2*std.pow(2))).sum(1).mean().div(math.log(2))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))
        xent_losses.update(xent_loss.item(), pids.size(0))
        info_losses.update(info_loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            if not printed:
                printed = True
            else:
                # Clean the current line
                sys.stdout.console.write("\033[F\033[K")
                #sys.stdout.console.write("\033[K")
            if args.jsd:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                    'Xent_Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                    'JSD_Loss {confidence_loss.val:.4f} ({confidence_loss.avg:.4f})\t'
                    'Info_Loss {info_loss.val:.4f} ({info_loss.avg:.4f})\t'
                    'Total_Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        epoch + 1, batch_idx + 1, len(trainloader),
                        batch_time=batch_time,
                        data_time=data_time,
                        xent_loss=xent_losses,
                        confidence_loss=confidence_losses,
                        info_loss=info_losses,  # FIX: was missing -> KeyError
                        loss=losses), text_dict)
            else:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                    'Xent_Loss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                    'Confi_Loss {confidence_loss.val:.4f} ({confidence_loss.avg:.4f})\t'
                    'Info_Loss {info_loss.val:.4f} ({info_loss.avg:.4f})\t'
                    'Total_Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        epoch + 1, batch_idx + 1, len(trainloader),
                        batch_time=batch_time,
                        data_time=data_time,
                        xent_loss=xent_losses,
                        confidence_loss=confidence_losses,
                        info_loss=info_losses,
                        loss=losses), text_dict)
        end = time.time()
    # Epoch-level TensorBoard scalars.
    writer.add_scalars(
        'loss',
        dict(loss=losses.avg,
             xent_loss=xent_losses.avg,
             info_loss=info_losses.avg,
             confidence_loss=confidence_losses.avg), epoch + 1)
def _compute_loss(self, criterion, outputs, targets): if isinstance(outputs, (list)): # delete tuple here loss = DeepSupervision(criterion, outputs, targets) else: loss = criterion(outputs, targets) return loss
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, writer=None):
    """Run one epoch of xent + triplet training for a graph-based model
    (batches carry an adjacency tensor), with ETA/speed logging.

    Fix vs. original: ``writer`` defaults to None but was dereferenced
    unconditionally at the end of the epoch, crashing with AttributeError
    whenever no writer was supplied; the scalar logging is now guarded.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: called as ``model(imgs, adj)``; returns ``(outputs, features)``.
        criterion_xent: cross-entropy loss on outputs.
        criterion_htri: triplet loss on features.
        optimizer: optimizer over the trainable parameters.
        trainloader: yields ``(imgs, pids, _, adj)`` batches.
        use_gpu: move batch tensors to CUDA when True.
        writer: optional TensorBoard-style writer; skipped when None.
    """
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    precisions = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    end = time.time()
    for batch_idx, (imgs, pids, _, adj) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids, adj = imgs.cuda(), pids.cuda(), adj.cuda()
        outputs, features = model(imgs, adj)
        if isinstance(outputs, tuple) or isinstance(outputs, list):
            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
        else:
            xent_loss = criterion_xent(outputs, pids)
        if isinstance(features, tuple) or isinstance(features, list):
            htri_loss = DeepSupervision(criterion_htri, features, pids)
        else:
            htri_loss = criterion_htri(features, pids)
        loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        xent_losses.update(xent_loss.item(), pids.size(0))
        htri_losses.update(htri_loss.item(), pids.size(0))
        precisions.update(metrics.accuracy(outputs, pids).mean(axis=0)[0])
        if ((batch_idx + 1) % args.print_freq == 0) or (args.print_last and batch_idx == (len(trainloader) - 1)):
            # ETA covers the rest of this epoch plus all remaining epochs.
            num_batches = len(trainloader)
            eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) + (args.max_epoch - (epoch + 1)) * num_batches)
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print('CurTime: {0}\t'
                  'Epoch: [{1}][{2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {speed:.3f} samples/s\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                  'Top1 {prec.val:.4f} ({prec.avg:.4f})\t'
                  'Eta {eta}'.format(cur_time(), epoch + 1, batch_idx + 1,
                                     len(trainloader),
                                     speed=1 / batch_time.avg * imgs.shape[0],
                                     batch_time=batch_time,
                                     data_time=data_time,
                                     xent=xent_losses,
                                     htri=htri_losses,
                                     prec=precisions,
                                     eta=eta_str))
        end = time.time()
    # FIX: only log when a writer was actually provided (default is None).
    if writer is not None:
        writer.add_scalar(tag='loss/xent_loss',
                          scalar_value=xent_losses.avg,
                          global_step=epoch + 1)
        writer.add_scalar(tag='loss/htri_loss',
                          scalar_value=htri_losses.avg,
                          global_step=epoch + 1)