예제 #1
0
def save_models(net, check_point_file, model_file, image_h_w, onnx_file, opt_level):
    """Rebuild a model from a training checkpoint, save its weights and export it to ONNX.

    Args:
        net: backbone identifier forwarded to ``Model``.
        check_point_file: checkpoint path; its ``'state_dicts'`` entry is
            indexed with ``[0]`` to obtain the model weights.
        model_file: destination of ``torch.save(model.state_dict())``.
        image_h_w: pair whose two items are used as the spatial size of the
            dummy export input (by its name, ``(height, width)``).
        onnx_file: destination path of the ONNX export.
        opt_level: apex AMP optimisation level given to ``amp.initialize``.
    """
    # Keep every deserialized storage on the CPU, whatever device it was saved from.
    keep_on_cpu = (lambda storage, loc: storage)
    checkpoint = torch.load(check_point_file, map_location=keep_on_cpu)
    weights = dict(checkpoint['state_dicts'][0])

    # Dummy export batch: 10 samples, 3 channels, half precision, on the GPU.
    # It is created before the model so the global RNG is consumed in the same order.
    # NOTE(review): the input is always .half(); confirm the chosen opt_level
    # leaves model.base accepting fp16 input.
    height, width = image_h_w[0], image_h_w[1]
    example_batch = torch.randn(10, 3, height, width, device='cuda').half()

    model = Model(net, pretrained=False)
    model.load_state_dict(weights)
    model.cuda()

    # The Adam optimizer is only created so it can be handed to
    # amp.initialize; it is never stepped here.
    model, _unused_optimizer = amp.initialize(
        model,
        optim.Adam(model.parameters()),
        opt_level=opt_level,
    )

    # NOTE(review): the 16 "learned_%d" names come from the PyTorch ONNX
    # example; newer torch versions raise if more names than graph inputs are
    # given. Confirm against the installed torch version.
    onnx_inputs = ["actual_input_1"] + ["learned_%d" % idx for idx in range(16)]
    onnx_outputs = ["output1"]

    torch.save(model.state_dict(), model_file)
    torch.onnx.export(model.base, example_batch, onnx_file, verbose=True,
                      input_names=onnx_inputs, output_names=onnx_outputs)
예제 #2
0
def main(opt):
    """Evaluate a caption/CMS model on the 'test' split of ``VideoDataset``.

    Args:
        opt: option dict. Keys read here: 'batch_size', 'cms' and the matching
            'int_max_len' / 'eff_max_len' / 'att_max_len', 'cap_max_len',
            'dim_vis_feat', 'rnn_layer', 'dim_head', 'dim_model', 'dim_word',
            'dim_inner', 'num_layer', 'num_head', 'dropout',
            'load_checkpoint' and 'cuda'. The vocabulary sizes are written back
            as 'cms_vocab_size' and 'cap_vocab_size' before ``test`` is called.
    """
    dataset = VideoDataset(opt, 'test')
    dataloader = DataLoader(dataset,
                            collate_fn=test_collate_fn,
                            batch_size=opt['batch_size'],
                            shuffle=False)
    opt['cms_vocab_size'] = dataset.get_cms_vocab_size()
    opt['cap_vocab_size'] = dataset.get_cap_vocab_size()

    # The maximum CMS text length depends on which CMS type is evaluated.
    if opt['cms'] == 'int':
        cms_text_length = opt['int_max_len']
    elif opt['cms'] == 'eff':
        cms_text_length = opt['eff_max_len']
    else:
        cms_text_length = opt['att_max_len']

    model = Model(dataset.get_cap_vocab_size(),
                  dataset.get_cms_vocab_size(),
                  cap_max_seq=opt['cap_max_len'],
                  cms_max_seq=cms_text_length,
                  tgt_emb_prj_weight_sharing=True,
                  vis_emb=opt['dim_vis_feat'],
                  rnn_layers=opt['rnn_layer'],
                  d_k=opt['dim_head'],
                  d_v=opt['dim_head'],
                  d_model=opt['dim_model'],
                  d_word_vec=opt['dim_word'],
                  d_inner=opt['dim_inner'],
                  n_layers=opt['num_layer'],
                  n_head=opt['num_head'],
                  dropout=opt['dropout'])

    if opt['load_checkpoint']:
        # The model is still on the CPU here. Loading to the CPU lets a
        # GPU-saved checkpoint be evaluated when opt['cuda'] is False.
        state_dict = torch.load(opt['load_checkpoint'], map_location='cpu')
        model.load_state_dict(state_dict)

    if opt['cuda']:
        model = model.cuda()

    model.eval()
    # Number of trainable parameters, printed for reference.
    params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print(params)
    test(dataloader, model, opt, dataset.get_cap_vocab(),
         dataset.get_cms_vocab())
예제 #3
0
def main(opt):
    """Evaluate a caption/CMS model on the 'test' split of ``VideoDataset`` (GPU required).

    Args:
        opt: option dict. Keys read here: 'batch_size', 'cms' and the matching
            'int_max_len' / 'eff_max_len' / 'att_max_len', 'cap_max_len',
            'dim_vis_feat', 'rnn_layer', 'dim_head', 'dim_model', 'dim_word',
            'dim_inner', 'num_layer', 'num_head', 'dropout' and
            'load_checkpoint'. The vocabulary sizes are written back as
            'cms_vocab_size' and 'cap_vocab_size' before ``test`` is called.
    """
    dataset = VideoDataset(opt, 'test')
    dataloader = DataLoader(dataset,
                            batch_size=opt['batch_size'],
                            shuffle=False)
    opt['cms_vocab_size'] = dataset.get_cms_vocab_size()
    opt['cap_vocab_size'] = dataset.get_cap_vocab_size()

    # The maximum CMS text length depends on which CMS type is evaluated.
    if opt['cms'] == 'int':
        cms_text_length = opt['int_max_len']
    elif opt['cms'] == 'eff':
        cms_text_length = opt['eff_max_len']
    else:
        cms_text_length = opt['att_max_len']

    model = Model(dataset.get_cap_vocab_size(),
                  dataset.get_cms_vocab_size(),
                  cap_max_seq=opt['cap_max_len'],
                  cms_max_seq=cms_text_length,
                  tgt_emb_prj_weight_sharing=True,
                  vis_emb=opt['dim_vis_feat'],
                  rnn_layers=opt['rnn_layer'],
                  d_k=opt['dim_head'],
                  d_v=opt['dim_head'],
                  d_model=opt['dim_model'],
                  d_word_vec=opt['dim_word'],
                  d_inner=opt['dim_inner'],
                  n_layers=opt['num_layer'],
                  n_head=opt['num_head'],
                  dropout=opt['dropout'])

    if opt['load_checkpoint']:
        # The model is still on the CPU here. Loading to the CPU avoids failures
        # when the checkpoint was saved from a GPU index this machine lacks.
        state_dict = torch.load(opt['load_checkpoint'], map_location='cpu')
        model.load_state_dict(state_dict)

    model = model.cuda()
    model.eval()
    test(dataloader, model, opt, dataset.get_cap_vocab(),
         dataset.get_cms_vocab())
예제 #4
0
def main(opt):
    """Train a caption/CMS model and write its logs into a per-run directory.

    Args:
        opt: option dict. Besides the model hyper-parameters used in the
            ``Model`` call below, it reads 'batch_size', 'cms', 'cuda',
            'load_checkpoint', 'resume', 'warm_up_steps' and 'checkpoint_path'.
            It is mutated in place: 'cms_vocab_size', 'cap_vocab_size',
            'init_lr', 'output_dir', 'info_path' and 'model_info_path' are
            set before it is handed to ``train``.
    """
    # load and define dataloader
    dataset = VideoDataset(opt, 'train')
    dataloader = DataLoader(dataset, batch_size=opt['batch_size'], shuffle=True)

    opt['cms_vocab_size'] = dataset.get_cms_vocab_size()
    opt['cap_vocab_size'] = dataset.get_cap_vocab_size()

    # The maximum CMS text length depends on which CMS type is trained.
    if opt['cms'] == 'int':
        cms_text_length = opt['int_max_len']
    elif opt['cms'] == 'eff':
        cms_text_length = opt['eff_max_len']
    else:
        cms_text_length = opt['att_max_len']

    # model initialization.
    model = Model(
        dataset.get_cap_vocab_size(),
        dataset.get_cms_vocab_size(),
        cap_max_seq=opt['cap_max_len'],
        cms_max_seq=cms_text_length,
        tgt_emb_prj_weight_sharing=True,
        vis_emb=opt['dim_vis_feat'],
        rnn_layers=opt['rnn_layer'],
        d_k=opt['dim_head'],
        d_v=opt['dim_head'],
        d_model=opt['dim_model'],
        d_word_vec=opt['dim_word'],
        d_inner=opt['dim_inner'],
        n_layers=opt['num_layer'],
        n_head=opt['num_head'],
        dropout=opt['dropout'])

    # number of learnable parameters
    params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print('number of learnable parameters are {}'.format(params))

    if opt['cuda']:
        model = model.cuda()

    # Resume from a previous checkpoint if requested. The checkpoint is merged
    # into the current state dict, so it may supply only a subset of the keys.
    # Loading to the CPU lets a GPU checkpoint resume a CPU-only run;
    # load_state_dict copies the values onto the model's own device.
    if opt['load_checkpoint'] and opt['resume']:
        cap_state_dict = torch.load(opt['load_checkpoint'], map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update(cap_state_dict)
        model.load_state_dict(model_dict)

    # NOTE(review): 512 is hard-coded here. The comment below says the initial
    # LR is np.power(d_model, -0.5), so this argument is presumably d_model.
    # Confirm it should not be opt['dim_model'] when that differs from 512.
    optimizer = ScheduledOptim(optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
                                          betas=(0.9, 0.98), eps=1e-09), 512, opt['warm_up_steps'])

    # note: though we set the init learning rate as np.power(d_model, -0.5),
    # grid search indicates different LR may improve the results.
    opt['init_lr'] = round(optimizer.init_lr, 3)

    # Create the checkpoint output directory.
    # NOTE(review): the name always contains 'INT', whatever opt['cms'] is, so
    # runs that differ only in CMS type share one directory.
    output_dir = os.path.join(opt['checkpoint_path'], 'CMS_CAP_MODEL_INT_lr_{}_BS_{}_Layer_{}_ATTHEAD_{}_HID_{}_RNNLayer_{}'
                              .format(opt['init_lr'], opt['batch_size'], opt['num_layer'],
                                      opt['num_head'], opt['dim_model'], opt['rnn_layer']))
    os.makedirs(output_dir, exist_ok=True)

    # Append the model architecture and parameter count to the run's info log.
    info_path = os.path.join(output_dir, 'iteration_info_log.log')
    print('model architecture saved to {} \n {}'.format(info_path, str(model)))
    with open(info_path, 'a') as f:
        f.write(str(model))
        f.write('\n')
        f.write(str(params))
        f.write('\n')

    # log file directory
    opt['output_dir'] = output_dir
    opt['info_path'] = info_path
    opt['model_info_path'] = os.path.join(opt['output_dir'],
                                          'checkpoint_loss_log.log')

    train(dataloader, model, optimizer, opt, dataset.get_cap_vocab(), dataset.get_cms_vocab())
예제 #5
0
                      train=False,
                      test=False,
                      valid=True)
# Train/validation loaders: batch size 64 and 8 workers; only the training
# set is shuffled.
trainloader = DataLoaderX(trainset, batch_size=64, shuffle=True, num_workers=8)
valloader = DataLoaderX(valset, batch_size=64, shuffle=False, num_workers=8)

# ResNet-50 backbone created with pretrained=True. Its classification head is
# replaced by a 2048 -> 68 linear layer (68 is presumably the number of target
# labels; confirm), then overwritten with previously fine-tuned weights.
# NOTE(review): the checkpoint path is an absolute, machine-specific path.
resnet = resnet50(pretrained=True)
resnet.fc = nn.Linear(2048, 68)
resnet.load_state_dict(
    torch.load(
        '/home/hh9665/Desktop/CurrentProject/AutoEncoderMPNN/ckp/model_resnet50_my_loss1.pth'
    ))
# Full model wrapping the backbone together with the AutoEncoder and MPNN classes.
model = Model(resnet, AutoEncoder, MPNN, TRAIN=True, feature_d=16)
print(model)
# model.load_state_dict(torch.load('/home/hh9665/Desktop/CurrentProject/AutoEncoderMPNN/result/T20200224_0/model0.pth'))
model.cuda()
# Adam at lr 5e-5 without weight decay; StepLR multiplies the LR by 0.9 every
# 3 scheduler.step() calls.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005, weight_decay=0)
scheduler = StepLR(optimizer, step_size=3, gamma=0.9)
N_epoch = 80

RESULTS = {
    'LOSS': [],
    'BCELOSS': [],
    'AUTOENCODER_LOSS': [],
    'ACC': [],
    'mAP': [],
    'miF1': [],
    'maF1': [],
    'LOSS_val': [],
    'ACC_val': [],
    'mAP_val': [],
예제 #6
0
class Solver(object):
    """Train/test driver for a network that takes an ``image`` and a ``flow`` tensor.

    The network (``Model(3, mode)``) returns four logit maps. During training
    each map is compared with ``label`` using a summed BCE-with-logits loss.
    During testing the first map is min-max normalised and written as one
    image per sample.

    Besides the constructor arguments, this class reads module-level globals
    defined outside the visible code: ``p`` (a dict with 'lr_bone',
    'lr_branch' and 'wd'), ``tmp_path``, ``nAveGrad``, ``showEvery`` and
    ``lr_decay_epoch``.
    """

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Store the loaders and config, then build the network and its optimizer.

        Args:
            train_loader: loader yielding dicts with 'image', 'label', 'flow'.
            test_loader: loader yielding dicts with 'image', 'flow', 'name',
                'split', 'size' and 'dataset'.
            config: options object (mode, cuda, model_path, batch sizes, ...).
            save_fold: output directory for the predictions written by ``test``.
        """
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.save_fold = save_fold

        self.build_model()

        if config.mode == 'test':
            self.net_bone.eval()

    def print_network(self, model, name):
        """Print ``name``, the module ``model`` and its total parameter element count."""
        # numel() is the number of elements of one tensor; sum it over all parameters.
        num_params = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def _make_optimizer(self):
        """Return Adam over the trainable parameters, using the current ``self.lr_bone``."""
        return Adam(filter(lambda param: param.requires_grad,
                           self.net_bone.parameters()),
                    lr=self.lr_bone,
                    weight_decay=p['wd'])

    def _save_bone(self, filename):
        """Save the network weights to ``<config.save_fold>/models/<filename>``.

        The ``models`` directory is created first, so ``torch.save`` does not
        fail on a fresh output folder.
        """
        model_dir = '%s/models' % self.config.save_fold
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.net_bone.state_dict(), '%s/%s' % (model_dir, filename))

    # build the network
    def build_model(self):
        """Create the network, optionally load pretrained weights, and build Adam.

        Raises:
            AssertionError: if ``config.model_path`` is set but missing, or if
                it is empty in any mode other than 'train'.
        """
        print('mode: {}'.format(self.config.mode))
        print('------------------------------------------')
        self.net_bone = Model(3, self.config.mode)
        if self.config.cuda:
            self.net_bone = self.net_bone.cuda()

        if self.config.mode == 'train':
            # Pretrained weights are optional when training.
            if self.config.model_path != '':
                assert (os.path.exists(self.config.model_path)), (
                    'please import correct pretrained model path!')
                self.net_bone.load_pretrain_model(self.config.model_path)
        else:
            # Every other mode needs an existing weight file.
            assert (self.config.model_path !=
                    ''), ('Test mode, please import pretrained model path!')
            assert (os.path.exists(self.config.model_path)), (
                'please import correct pretrained model path!')
            self.net_bone.load_pretrain_model(self.config.model_path)

        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']  # stored, but not used by this class
        self.optimizer_bone = self._make_optimizer()
        print('------------------------------------------')
        self.print_network(self.net_bone, 'DSNet')
        print('------------------------------------------')

    def test(self):
        """Write one min-max normalised prediction image per test sample.

        Images go to ``<save_fold>/<dataset>/<split>/<name>``.
        """
        os.makedirs(self.save_fold, exist_ok=True)
        for data_batch in self.test_loader:
            image, flow = data_batch['image'], data_batch['flow']
            name, split, size = data_batch['name'], data_batch['split'], data_batch['size']
            dataset = data_batch['dataset']

            if self.config.cuda:
                image, flow = image.cuda(), flow.cuda()
            with torch.no_grad():

                pre, pre2, pre3, pre4 = self.net_bone(image, flow)

                # Iterate over the real batch size: the last batch can be
                # smaller than config.test_batch_size.
                for j in range(image.size(0)):
                    presavefold = os.path.join(self.save_fold, dataset[j],
                                               split[j])
                    os.makedirs(presavefold, exist_ok=True)
                    pre1 = torch.sigmoid(pre[j])
                    lo, hi = torch.min(pre1), torch.max(pre1)
                    # Guard against a constant map, which would divide by zero.
                    pre1 = (pre1 - lo) / (hi - lo).clamp_min(1e-8)
                    pre1 = np.squeeze(pre1.cpu().data.numpy()) * 255
                    # NOTE(review): every sample is resized with size[0]; confirm
                    # how 'size' is collated when batches hold several images.
                    pre1 = cv2.resize(pre1, (size[0][1], size[0][0]))
                    cv2.imwrite(os.path.join(presavefold, name[j]), pre1)

    def train(self):
        """Train for ``config.epoch`` epochs with gradient accumulation.

        Each batch loss is divided by ``nAveGrad * batch_size`` and
        back-propagated, and the optimizer steps every ``nAveGrad`` batches.
        Other periodic actions:

        - running losses are printed every ``showEvery`` iterations;
        - debug images are written to ``tmp_path`` every 50 iterations;
        - weights are saved every ``config.epoch_save`` epochs and at the end;
        - the learning rate is multiplied by 0.2 after each epoch listed in
          ``lr_decay_epoch``.
        """
        # Number of batches per epoch (only used in the progress message).
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        # Losses are summed over pixels and samples, then averaged over the
        # nAveGrad * batch_size samples that contribute to one optimizer step.
        loss_norm = nAveGrad * self.config.batch_size
        aveGrad = 0
        os.makedirs(tmp_path, exist_ok=True)
        for epoch in range(self.config.epoch):
            r_img_loss, r_flo_loss, r_pre_loss, r_sal_loss, r_sum_loss = 0, 0, 0, 0, 0
            self.net_bone.zero_grad()
            for i, data_batch in enumerate(self.train_loader):
                image, label, flow = data_batch['image'], data_batch['label'], data_batch['flow']
                if image.size()[2:] != label.size()[2:]:
                    print("Skip this batch")
                    continue
                if self.config.cuda:
                    image, label, flow = image.cuda(), label.cuda(), flow.cuda()

                pre1, pre2, pre3, pre4 = self.net_bone(image, flow)

                # Summed BCE-with-logits for each of the four outputs.
                loss1, loss2, loss3, loss4 = (
                    F.binary_cross_entropy_with_logits(out, label, reduction='sum')
                    for out in (pre1, pre2, pre3, pre4))

                sal_img = loss3 / loss_norm
                sal_flo = loss4 / loss_norm
                sal_pre = loss2 / loss_norm
                sal_final = loss1 / loss_norm

                r_img_loss += sal_img.data
                r_flo_loss += sal_flo.data
                r_pre_loss += sal_pre.data
                r_sal_loss += sal_final.data

                sal_loss = (loss1 + loss2 + loss3 + loss4) / loss_norm
                r_sum_loss += sal_loss.data
                sal_loss.backward()
                aveGrad += 1

                # Step only once nAveGrad batches of gradients have accumulated.
                if aveGrad % nAveGrad == 0:
                    self.optimizer_bone.step()
                    self.optimizer_bone.zero_grad()
                    aveGrad = 0

                if i % showEvery == 0:
                    print(
                        'epoch: [%2d/%2d], iter: [%5d/%5d]  Loss ||  img : %10.4f  ||  flo : %10.4f ||  pre : %10.4f || sal : %10.4f || sum : %10.4f'
                        % (epoch, self.config.epoch, i, iter_num,
                           r_img_loss * loss_norm / showEvery,
                           r_flo_loss * loss_norm / showEvery,
                           r_pre_loss * loss_norm / showEvery,
                           r_sal_loss * loss_norm / showEvery,
                           r_sum_loss * loss_norm / showEvery))

                    print('Learning rate: ' + str(self.lr_bone))
                    r_img_loss, r_flo_loss, r_pre_loss, r_sal_loss, r_sum_loss = 0, 0, 0, 0, 0

                if i % 50 == 0:
                    vutils.save_image(torch.sigmoid(pre1.data),
                                      tmp_path + '/iter%d-sal-0.jpg' % i,
                                      normalize=True,
                                      padding=0)
                    vutils.save_image(image.data,
                                      tmp_path + '/iter%d-sal-data.jpg' % i,
                                      padding=0)
                    vutils.save_image(label.data,
                                      tmp_path + '/iter%d-sal-target.jpg' % i,
                                      padding=0)

            if (epoch + 1) % self.config.epoch_save == 0:
                self._save_bone('epoch_%d_bone.pth' % (epoch + 1))

            if epoch in lr_decay_epoch:
                # Rebuilding Adam applies the new LR but also discards its
                # moment estimates.
                self.lr_bone = self.lr_bone * 0.2
                self.optimizer_bone = self._make_optimizer()

        self._save_bone('final_bone.pth')
예제 #7
0
def main():
    """Extract test-set features and save rank-list images for fixed queries.

    All settings come from ``Config()``. Weights come from
    ``cfg.model_weight_file`` when it is non-empty; otherwise
    ``cfg.ckpt_file['state_dicts'][0]`` is used. Query images are chosen with a
    fixed RNG seed after sorting by image name, so visualisations from
    different models can be compared.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        for log_path, stream_name in ((cfg.stdout_file, 'stdout'),
                                      (cfg.stderr_file, 'stderr')):
            ReDirectSTD(log_path, stream_name, False)

    TVT, TMO = set_devices(cfg.sys_device_ids)

    # Dump the configuration to the log.
    import pprint
    rule = '-' * 60
    print(rule)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print(rule)

    # ----- dataset -----
    test_set = create_dataset(**cfg.test_set_kwargs)

    # ----- model -----
    model = Model(cfg.net,
                  path_to_predefined='',
                  pretrained=False,
                  last_conv_stride=cfg.last_conv_stride)
    model.cuda()
    r'''
  This is compeletly useless, but since I used apex, and its optimization level
  has different effect on each layer of networ and optimizer is mandatory argument, I created this optimizer.
  '''
    # The optimizer is only created so it can be passed to amp.initialize; it is never stepped.
    optimizer = optim.Adam(model.parameters())
    model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.opt_level)

    print(model)

    model_w = DataParallel(model)
    TMO([model])

    # ----- weights -----
    # Deserialize on the CPU first; a non-empty model_weight_file overrides the checkpoint.
    weights_path = cfg.model_weight_file or cfg.ckpt_file
    loaded = torch.load(weights_path, map_location=lambda storage, loc: storage)
    if cfg.model_weight_file == '':
        loaded = loaded['state_dicts'][0]
    load_state_dict(model, loaded)
    print('Loaded model weights from {}'.format(weights_path))

    # ----- features -----
    test_set.set_feat_func(ExtractFeature(model_w, TVT))
    with measure_time('Extracting feature...', verbose=True):
        feat, ids, cams, im_names, marks = test_set.extract_feat(True, verbose=True)

    # Sort everything by image name so the selected queries are stable across models.
    order = np.argsort(im_names)
    feat, ids, cams, im_names, marks = (
        column[order] for column in (feat, ids, cams, im_names, marks))

    # marks == 0 -> query, marks == 1 -> gallery
    is_q = marks == 0
    is_g = marks == 1

    prng = np.random.RandomState(2)
    sel_q_inds = prng.permutation(range(np.sum(is_q)))[:cfg.num_queries]

    q_ids, q_cams, q_feat, q_im_names = (
        column[is_q][sel_q_inds] for column in (ids, cams, feat, im_names))

    # Query-to-gallery Euclidean distances.
    q_g_dist = compute_dist(q_feat, feat[is_g], type='euclidean')

    # ----- rank-list images -----
    def as_text(raw_name):
        # Image names may be stored as bytes; paths need str.
        return raw_name.decode("utf-8") if isinstance(raw_name, bytes) else raw_name

    q_im_paths = [ospj(test_set.im_dir, as_text(n)) for n in q_im_names]
    save_paths = [ospj(cfg.exp_dir, 'rank_lists', as_text(n)) for n in q_im_names]
    g_im_paths = [ospj(test_set.im_dir, as_text(n)) for n in im_names[is_g]]

    for dist_vec, q_id, q_cam, q_im_path, save_path in zip(
            q_g_dist, q_ids, q_cams, q_im_paths, save_paths):
        rank_list, same_id = get_rank_list(dist_vec, q_id, q_cam, ids[is_g],
                                           cams[is_g], cfg.rank_list_size)
        save_rank_list_to_im(rank_list, same_id, q_im_path, g_im_paths, save_path)
예제 #8
0
def main():
  """Train (or only test) an embedding model with a triplet loss under apex AMP.

  All settings come from ``Config()``. With ``cfg.only_test`` the weights are
  loaded, the test set(s) are evaluated, and the function returns. Otherwise
  the model is trained for ``cfg.total_epochs`` epochs, with optional
  validation, TensorBoard logging and checkpointing, and is then evaluated on
  the test set(s). The 'market1501' / 'cuhk03' / 'duke' test sets suggest
  person re-identification (not confirmed by this code).
  """
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVT, TMO = set_devices(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  # train_set / val_set are only defined when training; the only_test path
  # returns before they are used.
  if not cfg.only_test:
    train_set = create_dataset(**cfg.train_set_kwargs)
    # The combined dataset does not provide val set currently.
    val_set = None if cfg.dataset == 'combined' else create_dataset(**cfg.val_set_kwargs)

  # The 'combined' setting evaluates on three test sets, one after another.
  # Note: this mutates cfg.test_set_kwargs['name'].
  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ###########
  # Models  #
  ###########

  if cfg.only_test:
    model = Model(cfg.net, pretrained=False, last_conv_stride=cfg.last_conv_stride)
  else:
    model = Model(cfg.net, path_to_predefined=cfg.net_pretrained_path, last_conv_stride=cfg.last_conv_stride) # This is a ShuffleNet Network. Model(last_conv_stride=cfg.last_conv_stride)

  #############################
  # Criteria and Optimizers   #
  #############################

  tri_loss = TripletLoss(margin=cfg.margin)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.base_lr,
                         weight_decay=cfg.weight_decay)
  #optimizer = optimizers.FusedAdam(model.parameters(),
  #                      lr=cfg.base_lr,
  #                      weight_decay=cfg.weight_decay)

  #optimizer = torch.optim.SGD(model.parameters(), cfg.base_lr,
  #                            nesterov=True,
  #                            momentum=cfg.momentum,
  #                            weight_decay=cfg.weight_decay)

  # The model must be on the GPU before amp.initialize wraps model and optimizer.
  model.cuda()
  model, optimizer = amp.initialize(model, optimizer,
                                    opt_level=cfg.opt_level,
                                    keep_batchnorm_fp32=cfg.keep_batchnorm_fp32,
                                    #loss_scale=cfg.loss_scale
                                    )

  # NOTE(review): amp.init() is apex's older API. Calling it after
  # amp.initialize may be redundant or conflict; confirm with the installed apex.
  amp.init() # Register function

  # Bind them together just to save some codes in the following usage.
  modules_optims = [model, optimizer]

  # Model wrapper
  model_w = DataParallel(model)


  ################################
  # May Resume Models and Optims #
  ################################

  # resume_ep is only defined when cfg.resume is set (see start_ep below).
  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring optimizer
  # is to cope with the case when you load the checkpoint to a new device.
  TMO(modules_optims)

  ########
  # Test #
  ########

  def test(load_model_weight=False):
    """Evaluate on every test set, optionally (re)loading weights first.

    With load_model_weight, weights come from cfg.model_weight_file when it is
    non-empty, otherwise from the checkpoint cfg.ckpt_file.
    """
    if load_model_weight:
      if cfg.model_weight_file != '':
        # Deserialize on the CPU first.
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.model_weight_file, map_location=map_location)
        load_state_dict(model, sd)
        print('Loaded model weights from {}'.format(cfg.model_weight_file))
      else:
        load_ckpt(modules_optims, cfg.ckpt_file)

    for test_set, name in zip(test_sets, test_set_names):
      feature_map = ExtractFeature(model_w, TVT)
      test_set.set_feat_func(feature_map)
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(
        normalize_feat=cfg.normalize_feature,
        verbose=True)

  def validate():
    """Evaluate on val_set without re-ranking; return (mAP, first CMC score)."""
    # The feature-extraction function is attached only once.
    if val_set.extract_feat_func is None:
      feature_map = ExtractFeature(model_w, TVT)
      val_set.set_feat_func(feature_map)
    print('\n=========> Test on validation set <=========\n')
    mAP, cmc_scores, _, _ = val_set.eval(
      normalize_feat=cfg.normalize_feature,
      to_re_rank=False,
      verbose=False)
    print()
    return mAP, cmc_scores[0]

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  start_ep = resume_ep if cfg.resume else 0

  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate (epochs are passed 1-based).
    if cfg.lr_decay_type == 'exp':
      adjust_lr_exp(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.total_epochs,
        cfg.exp_decay_at_epoch)
    else:
      adjust_lr_staircase(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.staircase_decay_at_epochs,
        cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    # For recording precision, satisfying margin, etc
    prec_meter = AverageMeter()
    sm_meter = AverageMeter()
    dist_ap_meter = AverageMeter()
    dist_an_meter = AverageMeter()
    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    epoch_done = False
    # train_set.next_batch() reports when the epoch is finished.
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      # NOTE(review): Variable is a legacy wrapper (a no-op on modern PyTorch).
      ims_var = Variable(TVT(torch.from_numpy(ims).float()))

      labels_t = TVT(torch.from_numpy(labels).long())

      feat = model_w(ims_var)

      loss, p_inds, n_inds, dist_ap, dist_an, dist_mat = global_loss(
        tri_loss, feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      optimizer.zero_grad()

      # apex scales the loss before backward for mixed-precision training.
      with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

      optimizer.step()

      ############
      # Step Log #
      ############

      # precision
      prec = (dist_an > dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      sm = (dist_an > dist_ap + cfg.margin).data.float().mean()
      # average (anchor, positive) distance
      d_ap = dist_ap.data.mean()
      # average (anchor, negative) distance
      d_an = dist_an.data.mean()

      prec_meter.update(prec)
      sm_meter.update(sm)
      dist_ap_meter.update(d_ap)
      dist_an_meter.update(d_an)
      loss_meter.update(to_scalar(loss))

      if step % cfg.steps_per_log == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st, )

        tri_log = (', prec {:.2%}, sm {:.2%}, '
                   'd_ap {:.4f}, d_an {:.4f}, '
                   'loss {:.4f}'.format(
          prec_meter.val, sm_meter.val,
          dist_ap_meter.val, dist_an_meter.val,
          loss_meter.val, ))

        log = time_log + tri_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

    tri_log = (', prec {:.2%}, sm {:.2%}, '
               'd_ap {:.4f}, d_an {:.4f}, '
               'loss {:.4f}'.format(
      prec_meter.avg, sm_meter.avg,
      dist_ap_meter.avg, dist_an_meter.avg,
      loss_meter.avg, ))

    log = time_log + tri_log
    print(log)

    ##########################
    # Test on Validation Set #
    ##########################

    # Epochs without validation log mAP = Rank1 = 0 to TensorBoard below.
    mAP, Rank1 = 0, 0
    if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
      mAP, Rank1 = validate()

    # Log to TensorBoard

    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'val scores',
        dict(mAP=mAP,
             Rank1=Rank1),
        ep)
      writer.add_scalars(
        'loss',
        dict(loss=loss_meter.avg, ),
        ep)
      writer.add_scalars(
        'precision',
        dict(precision=prec_meter.avg, ),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(satisfy_margin=sm_meter.avg, ),
        ep)
      writer.add_scalars(
        'average_distance',
        dict(dist_ap=dist_ap_meter.avg,
             dist_an=dist_an_meter.avg, ),
        ep)

    # save ckpt (only when logging to file; the score argument is always 0)
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)