def validate(self,
                 net,
                 priors=None,
                 criterion=None,
                 use_cuda=True,
                 tb_writer=None):
        """Run detection over ``self.data_loader`` and return the evaluation stats.

        Args:
            net: Model called as ``net(images, phase='eval')``; it must return
                a ``(loc, conf)`` pair.
            priors: Prior boxes forwarded to ``self.detector``. Required, even
                though the signature defaults it to ``None``.
            criterion: Not used by this method; kept for call compatibility.
            use_cuda: When true, ``images``, ``extra`` and ``priors`` are moved
                to the GPU before inference.
            tb_writer: Optional writer. When ``tb_writer.cfg['show_test_image']``
                is truthy every batch is passed to ``self.visualize_box``; the
                writer is also forwarded to ``self.evaluate_stats``.

        Returns:
            The result of ``self.evaluate_stats(None, tb_writer)``.

        Raises:
            ValueError: If ``priors`` is ``None``.
        """
        print('start evaluation')
        if priors is None:
            raise ValueError('validate() requires priors, got None')
        if use_cuda:
            # Only touch the GPU when asked to: with use_cuda=False the images
            # and ``extra`` stay on the CPU, so the priors must stay there too.
            priors = priors.cuda(self.cfg.GENERAL.NET_CPUS[0])
        self.reset_results()
        img_idx = 0
        _t = {'im_detect': Timer(), 'misc': Timer()}
        _t['misc'].tic()
        for batch_idx, (images, targets, extra) in enumerate(self.data_loader):
            if batch_idx % 25 == 0:
                print('processed image', img_idx)
            if use_cuda:
                images = Variable(images.cuda(), volatile=True)
                extra = extra.cuda()
            else:
                images = Variable(images, volatile=True)

            # Only the forward pass and the detector are timed as 'im_detect'.
            _t['im_detect'].tic()
            loc, conf = net(images, phase='eval')
            # image, cls, #box, [score, xmin, ymin, xmax, ymax]
            detections = self.detector(loc, conf, priors)
            _t['im_detect'].toc(average=False)

            det = detections.data
            # extra[:, 0] / extra[:, 1] are used as the per-image height / width
            # factors; the two unsqueeze calls let them broadcast over the
            # class and box dimensions of ``det``.
            h = extra[:, 0].unsqueeze(-1).unsqueeze(-1)
            w = extra[:, 1].unsqueeze(-1).unsqueeze(-1)
            det[:, :, :, 1] *= w  # xmin
            det[:, :, :, 3] *= w  # xmax
            det[:, :, :, 2] *= h  # ymin
            det[:, :, :, 4] *= h  # ymax
            det, det_ids = self.convert_ssd_result(det, img_idx)
            # the format is now xmin, ymin, xmax, ymax, score, image, cls, (cocoid)
            if tb_writer is not None and tb_writer.cfg['show_test_image']:
                # NOTE: targets is changed by visualize_box
                self.visualize_box(images, targets, h, w, det, img_idx,
                                   tb_writer)
            img_idx = self.post_proc(det, targets, img_idx, det_ids)
        _t['misc'].toc(average=False)
        # Timing is otherwise printed by the caller (eval.py).
        if self.cfg.EVAL.ONLY_SAVE_RESULTS:
            print('model infer time', _t['im_detect'].total_time,
                  _t['misc'].total_time)

        return self.evaluate_stats(None, tb_writer)
Example #2
0
def main():
    """Train the model described by ``--config`` and exit once it is done.

    Command-line options:
        --config: Path to the YAML config file.
        --output_path: Root directory for outputs, logs and checkpoints.
        --resume: Flag; when given, training continues from the state that
            ``trainer.resume`` restores from the checkpoint directory.

    The function never returns normally: it terminates the process with exit
    status 0 after ``max_iter`` iterations.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/removal.yaml', help='Path to the config file.')
    parser.add_argument('--output_path', type=str, default='.', help="outputs path")
    # Boolean flag. Change the default to True if you need to retrain from a
    # pre-trained model without passing --resume every time.
    parser.add_argument("--resume", default=False, action="store_true")
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)

    # Dataset set up
    dataset = My3DDataset(opts=config)
    train_loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['nThreads'])

    config['vgg_model_path'] = opts.output_path

    trainer = Models(config)
    trainer.cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    output_directory = os.path.join(opts.output_path, "outputs", model_name)
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path, "outputs", "logs", model_name))
    checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
    shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory, hyperparameters=config, need_opt=True) if opts.resume else 0
    max_iter = int(config['n_ep'] * len(dataset) / config['batch_size']) + 1

    # Keep cycling through the loader; the only way out is the exit below.
    while True:
        for out_data in train_loader:
            # Move every element of the batch to the GPU, detached from any graph.
            for j, item in enumerate(out_data):
                out_data[j] = item.cuda().detach()

            trainer.update_learning_rate()
            with Timer("Elapsed time in update: %f"):
                # Main training code
                if config['models_name'] == 'removal':
                    trainer.gen_update(out_data, config)

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            if iterations % config['image_display_iter'] == 0:
                write_image2display(iterations, trainer, train_writer)

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0 or iterations + 1 == max_iter:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                # sys.exit() with a string argument reports exit status 1, which
                # made a successful run look like a failure. Emit the same
                # message on stderr and exit cleanly instead.
                sys.stderr.write('Finish training\n')
                sys.exit(0)
Example #3
0
    def validate(self, net, priors, use_cuda=True, tb_writer=None):
        """Run detection over ``self.data_loader`` and return the evaluation stats.

        Args:
            net: Model called as ``net(images, phase='eval')``; it must return
                a ``(loc, conf)`` pair.
            priors: Prior boxes forwarded unchanged to ``self.detector``.
            use_cuda: When true, ``images`` and ``extra`` are moved to the GPU.
            tb_writer: Optional writer. When ``tb_writer.cfg['show_test_image']``
                is truthy every batch is passed to ``self.visualize_box``; the
                writer is also forwarded to ``self.evaluate_stats``.

        Returns:
            The result of ``self.evaluate_stats(None, tb_writer)``.
        """
        print('start evaluation')
        self.reset_results()
        img_idx = 0
        detect_timer = Timer()
        misc_timer = Timer()
        misc_timer.tic()
        for batch_idx, (images, targets, extra) in enumerate(self.data_loader):
            if use_cuda:
                images = images.cuda()
                extra = extra.cuda()
            images = Variable(images, volatile=True)

            # Time the forward pass together with the detector.
            detect_timer.tic()
            loc, conf = net(images, phase='eval')
            # image, cls, #box, [score, xmin, ymin, xmax, ymax]
            detections = self.detector(loc, conf, priors)
            detect_timer.toc(average=False)

            det = detections.data
            # extra[:, 0] / extra[:, 1] act as per-image height / width factors,
            # reshaped so they broadcast over the class and box dimensions.
            h = extra[:, 0].unsqueeze(-1).unsqueeze(-1)
            w = extra[:, 1].unsqueeze(-1).unsqueeze(-1)
            # Columns 1/3 are xmin/xmax (scaled by w), 2/4 are ymin/ymax (by h).
            for column, factor in ((1, w), (3, w), (2, h), (4, h)):
                det[:, :, :, column] *= factor
            det, det_ids = self.convert_ssd_result(det, img_idx)
            # the format is now xmin, ymin, xmax, ymax, score, image, cls, (cocoid)
            show_images = (tb_writer is not None
                           and tb_writer.cfg['show_test_image'])
            if show_images:
                self.visualize_box(images, targets, h, w, det, img_idx,
                                   tb_writer)
            img_idx = self.post_proc(det, img_idx, det_ids)
            print('batch{}: {}s'.format(batch_idx, detect_timer.diff))

        misc_timer.toc(average=False)
        print(detect_timer.total_time, misc_timer.total_time)
        return self.evaluate_stats(None, tb_writer)
Example #4
0
    def validate(self, net, use_cuda=True, tb_writer=None):
        """Collect sigmoid scores and ground truths, then return the evaluation stats.

        Every batch from ``self.data_loader`` is run through ``net`` followed by
        a sigmoid. The scores are accumulated in ``self.preds`` and the targets
        in ``self.gts``; both lists are reset at the start of each call.

        Args:
            net: Model called as ``net(images)``.
            use_cuda: When true, ``images`` is moved to the GPU first.
            tb_writer: Optional writer, forwarded to ``self.evaluate_stats``.

        Returns:
            The result of ``self.evaluate_stats(None, tb_writer)``.
        """
        print('start evaluation')
        self.preds = []
        self.gts = []
        _t = {'im_detect': Timer(), 'misc': Timer()}
        _t['misc'].tic()
        sigout = nn.Sigmoid()

        # The third element of each batch is not needed by this method.
        for batch_idx, (images, targets, _extra) in enumerate(self.data_loader):
            if batch_idx % 25 == 0:
                print('processed image batch', batch_idx)
            if use_cuda:
                images = images.cuda()
            images = Variable(images, volatile=True)

            _t['im_detect'].tic()
            preds = sigout(net(images))
            _t['im_detect'].toc(average=False)

            self.preds += preds.data.tolist()
            self.gts.extend(gt.numpy().tolist() for gt in targets)

            # NOTE: no box visualisation here. The former visualize_box call
            # passed h, w and det, none of which is ever assigned in this
            # method, so it could only raise NameError whenever
            # tb_writer.cfg['show_test_image'] was enabled.

        _t['misc'].toc(average=False)
        return self.evaluate_stats(None, tb_writer)
def train():
    """Build the SSD model, load its weights and run a single evaluation.

    Reads the module-level ``args`` and ``cfg``. In this variant the whole
    optimisation step is commented out, so the loop only advances
    ``iteration`` until an evaluation iteration is reached
    (``iteration % cfg.TRAIN.EVAL_ITER == 0`` or ``cfg.TRAIN.MAX_ITER``),
    evaluates once, logs the mAP to TensorBoard and returns.
    ``backup_jobs`` is only reached when the loops finish without evaluating.
    """
    tb_writer, cfg_path, snapshot_dir, log_dir = setup_folder(args, cfg)
    step_index = 0

    train_loader = dataset_factory(phase='train', cfg=cfg)
    val_loader = dataset_factory(phase='eval', cfg=cfg)
    eval_solver = eval_solver_factory(val_loader, cfg)

    ssd_net, priors, _ = model_factory(phase='train', cfg=cfg, tb_writer=tb_writer)
    net = ssd_net  # net is the parallel version of ssd_net
    print(net)
    print(cfg.TRAIN.OPTIMIZER)

    if args.cuda:
        net = torch.nn.DataParallel(ssd_net)
        priors = Variable(priors.cuda(), volatile=True)
    else:
        priors = Variable(priors)

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_iter = checkpoint['iteration']
        step_index = checkpoint['step_index']
        ssd_net.load_state_dict(checkpoint['state_dict'])
    else:
        # pretrained weights
        pretrained_weights = torch.load(osp.join(cfg.GENERAL.WEIGHTS_ROOT, args.basenet))
        print('Loading base network...')

        try:
            ssd_net.base.load_state_dict(pretrained_weights)
        except (KeyError, RuntimeError):
            # The checkpoint does not match the base network exactly: keep
            # only the entries the network knows and load those.
            # NOTE(review): load_state_dict is expected to report mismatches
            # with RuntimeError (KeyError on very old torch releases) --
            # confirm against the pinned torch version. The previous bare
            # ``except:`` also swallowed KeyboardInterrupt / SystemExit.
            model_dict = ssd_net.base.state_dict()
            pretrained_weights = {k: v for k,
                                           v in pretrained_weights.items() if k in model_dict}
            model_dict.update(pretrained_weights)
            ssd_net.base.load_state_dict(model_dict)

        # initialize newly added layers' weights with xavier method
        print('Initializing weights...')
        ssd_net.extras.apply(weights_init)
        ssd_net.loc.apply(weights_init)
        ssd_net.conf.apply(weights_init)

    if args.cuda:
        net = net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=cfg.TRAIN.OPTIMIZER.LR,
                          momentum=cfg.TRAIN.OPTIMIZER.MOMENTUM,
                          weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY)
    criterion = MultiBoxLoss(cfg.MODEL.NUM_CLASSES, 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)

    # continue training at 8w, 12w...
    if args.start_iter not in cfg.TRAIN.LR_SCHEDULER.STEPS and step_index != 0:
        adjust_learning_rate(optimizer, cfg.TRAIN.LR_SCHEDULER.GAMMA, step_index)

    net.train()
    epoch_size = len(train_loader.dataset) // cfg.DATASET.TRAIN_BATCH_SIZE
    num_epochs = (cfg.TRAIN.MAX_ITER + epoch_size - 1) // epoch_size
    print('Training SSD on:', train_loader.dataset.name)
    print('Using the specified args:')
    print(args)

    # timer
    t_ = {'network': Timer(), 'misc': Timer(), 'all': Timer(), 'eval': Timer()}
    t_['all'].tic()

    iteration = args.start_iter
    for epoch in range(num_epochs):
        tb_writer.cfg['epoch'] = epoch
        for images, targets, _ in train_loader:
            tb_writer.cfg['iteration'] = iteration
            # --- training step, currently disabled -------------------------
            # t_['misc'].tic()
            # if iteration in cfg.TRAIN.LR_SCHEDULER.STEPS:
            #     t_['misc'].tic()
            #     step_index += 1
            #     adjust_learning_rate(optimizer, cfg.TRAIN.LR_SCHEDULER.GAMMA, step_index)
            #
            # if args.cuda:
            #     images = Variable(images.cuda())
            #     targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
            # else:
            #     images = Variable(images)
            #     targets = [Variable(ann, volatile=True) for ann in targets]

            # # forward
            # t_['network'].tic()
            # out = net(images)
            # out1 = [out[0], out[1], priors]
            #
            # # backward
            # optimizer.zero_grad()
            # loss_l, loss_c = criterion(out1, targets)
            # loss = loss_l + loss_c
            # loss.backward()
            # optimizer.step()
            # t_['network'].toc()
            #
            # # log
            # if iteration % cfg.TRAIN.LOG_LOSS_ITER == 0:
            #     t_['misc'].toc()
            #     print('Iter ' + str(iteration) + ' || Loss: %.3f' % (loss.data[0]) +
            #           '|| conf_loss: %.3f' % (loss_c.data[0]) + ' || loc loss: %.3f ' % (loss_l.data[0]), end=' ')
            #     print('Timer: %.3f sec.' % t_['misc'].diff, '  Lr: %.6f' % optimizer.param_groups[0]['lr'])
            #     if args.tensorboard:
            #         phase = tb_writer.cfg['phase']
            #         tb_writer.writer.add_scalar('{}/loc_loss'.format(phase), loss_l.data[0], iteration)
            #         tb_writer.writer.add_scalar('{}/conf_loss'.format(phase), loss_c.data[0], iteration)
            #         tb_writer.writer.add_scalar('{}/all_loss'.format(phase), loss.data[0], iteration)
            #         tb_writer.writer.add_scalar('{}/time'.format(phase), t_['misc'].diff, iteration)
            #
            # # save model
            # if iteration % cfg.TRAIN.SAVE_ITER == 0 and iteration != args.start_iter or \
            #         iteration == cfg.TRAIN.MAX_ITER:
            #     print('Saving state, iter:', iteration)
            #     save_checkpoint({'iteration': iteration,
            #                      'step_index': step_index,
            #                      'state_dict': ssd_net.state_dict()},
            #                     snapshot_dir,
            #                     args.cfg_name + '_' + repr(iteration) + '.pth')
            # ----------------------------------------------------------------

            # Eval
            if (iteration % cfg.TRAIN.EVAL_ITER == 0) or \
                    iteration == cfg.TRAIN.MAX_ITER:
                print('Start evaluation ......')
                tb_writer.cfg['phase'] = 'eval'
                t_['eval'].tic()
                net.eval()
                aps, mAPs = eval_solver.validate(net, priors, tb_writer=tb_writer)
                net.train()
                t_['eval'].toc()
                print('Iteration ' + str(iteration) + ' || mAP: %.3f' % mAPs[0] + ' ||eval_time: %.4f/%.4f' %
                      (t_['eval'].diff, t_['eval'].average_time))
                # NOTE(review): the scalar tags below look like the residue of
                # an e-mail obfuscation filter -- confirm the intended names.
                if cfg.DATASET.NAME == 'VOC0712':
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0], iteration)
                else:
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0], iteration)
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[1], iteration)
                tb_writer.cfg['phase'] = 'train'
                # This variant stops right after the first evaluation.
                return

            if iteration == cfg.TRAIN.MAX_ITER:
                break
            iteration += 1

    backup_jobs(cfg, cfg_path, log_dir)
Example #6
0
def train():
    """Train the SSD detector with periodic checkpointing and evaluation.

    Reads the module-level ``args`` and ``cfg``. Weights come from
    ``args.resume``, from ``cfg.MODEL.PRETRAIN_MODEL`` or from
    ``weights_init``. Each iteration optionally adjusts the learning rate,
    saves a checkpoint every ``cfg.TRAIN.SAVE_ITER`` iterations, evaluates
    every ``cfg.TRAIN.EVAL_ITER`` iterations and then performs one SGD step.
    Training stops after the checkpoint and evaluation taken at
    ``cfg.TRAIN.MAX_ITER``.
    """
    tb_writer, cfg_path, snapshot_dir, log_dir = setup_folder(args, cfg)
    step_index = 0

    train_loader = dataset_factory(phase='train', cfg=cfg)
    val_loader = dataset_factory(phase='eval', cfg=cfg)
    eval_solver = eval_solver_factory(val_loader, cfg)

    ssd_net, priors, _ = model_factory(phase='train', cfg=cfg)
    net = ssd_net  # net is the parallel version of ssd_net
    print(net)

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_iter = checkpoint['iteration'] + 1
        step_index = checkpoint['step_index']
        ssd_net.load_state_dict(checkpoint['state_dict'])
    elif cfg.MODEL.PRETRAIN_MODEL != '':
        # pretrained weights
        pretrain_weights = torch.load(cfg.MODEL.PRETRAIN_MODEL)
        if 'reducedfc' not in cfg.MODEL.PRETRAIN_MODEL:
            ssd_net.apply(weights_init)
            try:
                ssd_net.load_state_dict(pretrain_weights['state_dict'],
                                        strict=False)
            except RuntimeError:  # another dataset
                # Drop the entries whose key starts with 'conf' and retry.
                entries = [
                    i for i in pretrain_weights['state_dict'].keys()
                    if i.startswith('conf')
                ]
                for key in entries:
                    del pretrain_weights['state_dict'][key]
                ssd_net.load_state_dict(pretrain_weights['state_dict'],
                                        strict=False)
        else:
            print('Loading base network...')
            ssd_net.base.load_state_dict(pretrain_weights)

            # initialize newly added layers' weights with xavier method
            print('Initializing weights...')
            ssd_net.extras.apply(weights_init)
            ssd_net.loc.apply(weights_init)
            ssd_net.conf.apply(weights_init)
    else:
        print('Initializing weights...')
        ssd_net.apply(weights_init)
        ssd_net.extras.apply(weights_init)
        ssd_net.loc.apply(weights_init)
        ssd_net.conf.apply(weights_init)

    optimizer = optim.SGD(net.parameters(),
                          lr=cfg.TRAIN.OPTIMIZER.LR,
                          momentum=cfg.TRAIN.OPTIMIZER.MOMENTUM,
                          weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY)

    if args.cuda:
        net = torch.nn.DataParallel(ssd_net, device_ids=cfg.GENERAL.NET_CPUS)
        priors = Variable(priors.cuda(cfg.GENERAL.LOSS_GPU),
                          requires_grad=False)
        net = net.cuda()
    else:
        priors = Variable(priors, requires_grad=False)

    ssd_net.priors = priors
    ssd_net.criterion = DetectLoss(cfg)
    criterion_post = DetectLossPost(cfg)
    net.train()

    print('Using the specified args: \n', args)

    epoch_size = len(train_loader.dataset) // cfg.DATASET.TRAIN_BATCH_SIZE
    num_epochs = (cfg.TRAIN.MAX_ITER + epoch_size - 1) // epoch_size
    iteration = args.start_iter
    start_epoch = int(iteration * 1.0 / epoch_size)
    # continue training at 8w, 12w...
    if step_index > 0:
        adjust_learning_rate(optimizer, cfg.TRAIN.OPTIMIZER.LR,
                             cfg.TRAIN.LR_SCHEDULER.GAMMA, 100, step_index,
                             None, None)

    # timer
    t_ = {'network': Timer(), 'misc': Timer(), 'eval': Timer()}
    t_['misc'].tic()
    reached_max_iter = False
    for epoch in range(start_epoch, num_epochs):
        for images, targets, _ in train_loader:
            if iteration in cfg.TRAIN.LR_SCHEDULER.STEPS or (
                    iteration <= cfg.TRAIN.WARMUP_EPOCH * epoch_size):
                if iteration in cfg.TRAIN.LR_SCHEDULER.STEPS:
                    step_index += 1
                adjust_learning_rate(optimizer, cfg.TRAIN.OPTIMIZER.LR,
                                     cfg.TRAIN.LR_SCHEDULER.GAMMA, epoch,
                                     step_index, iteration, epoch_size,
                                     cfg.TRAIN.WARMUP_EPOCH)

            # save model
            if iteration % cfg.TRAIN.SAVE_ITER == 0 and iteration != args.start_iter or \
                    iteration == cfg.TRAIN.MAX_ITER:
                print('Saving state, iter:', iteration)
                save_checkpoint(
                    {
                        'iteration': iteration,
                        'step_index': step_index,
                        'state_dict': ssd_net.state_dict()
                    }, snapshot_dir,
                    args.cfg_name + '_' + repr(iteration) + '.pth')
            # Eval
            if iteration % cfg.TRAIN.EVAL_ITER == 0 or iteration == cfg.TRAIN.MAX_ITER:
                t_['eval'].tic()
                net.eval()
                aps, mAPs = eval_solver.validate(net,
                                                 priors,
                                                 tb_writer=tb_writer)
                net.train()
                t_['eval'].toc()
                print('Iteration ' + str(iteration) +
                      ' || mAP: %.3f' % mAPs[0] + ' ||eval_time: %.4f/%.4f' %
                      (t_['eval'].diff, t_['eval'].average_time))
                if tb_writer is not None:
                    # ``NAME == 'VOC0712' or 'FACE'`` was always truthy, which
                    # left the else branch unreachable; test membership instead.
                    # NOTE(review): the scalar tags look like the residue of an
                    # e-mail obfuscation filter, and the two tags in the else
                    # branch are identical -- confirm the intended names.
                    if cfg.DATASET.NAME in ('VOC0712', 'FACE'):
                        tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0],
                                                    iteration)
                    else:
                        tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0],
                                                    iteration)
                        tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[1],
                                                    iteration)

                if iteration == cfg.TRAIN.MAX_ITER:
                    reached_max_iter = True
                    break

            if args.cuda:
                images = Variable(images.cuda(), requires_grad=False)
                targets = [
                    Variable(ann.cuda(cfg.GENERAL.LOSS_GPU), volatile=True)
                    for ann in targets
                ]
            else:
                images = Variable(images)
                targets = [Variable(ann, volatile=True) for ann in targets]

            # forward
            t_['network'].tic()

            match_result = matching(targets,
                                    priors,
                                    cfg.LOSS.OVERLAP_THRESHOLD,
                                    cfg.MODEL.VARIANCE,
                                    args.cuda,
                                    cfg.GENERAL.LOSS_GPU,
                                    cfg=cfg)
            net_outputs = net(images,
                              match_result=match_result,
                              tb_writer=tb_writer)
            loss, (loss_l, loss_c) = criterion_post(net_outputs)

            loss_str = ' || Loss: %.3f' % (
                loss.data[0]) + '|| conf_loss: %.3f' % (
                    loss_c) + ' || loc_loss: %.3f ' % (loss_l)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t_['network'].toc()
            t_['misc'].toc()
            # log
            if iteration % cfg.TRAIN.LOG_LOSS_ITER == 0:
                current_date = time.strftime("%Y-%m-%d %H:%M:%S",
                                             time.localtime())

                print('Iter ' + str(iteration) + loss_str, end=' ')
                print(
                    'Timer: %.3f(%.3f) %.3f(%.3f) sec.' %
                    (t_['misc'].diff, t_['misc'].average_time,
                     t_['network'].diff, t_['network'].average_time),
                    'lr: %.6f' % optimizer.param_groups[0]['lr'], ' sys_time:',
                    current_date)

                if tb_writer is not None:
                    phase = 'train'
                    tb_writer.writer.add_scalar('{}/loc_loss'.format(phase),
                                                loss_l, iteration)
                    tb_writer.writer.add_scalar('{}/conf_loss'.format(phase),
                                                loss_c, iteration)
                    tb_writer.writer.add_scalar('{}/all_loss'.format(phase),
                                                loss.data[0], iteration)
                    tb_writer.writer.add_scalar('{}/time'.format(phase),
                                                t_['misc'].diff, iteration)

            iteration += 1
            t_['misc'].tic()

        if reached_max_iter:
            # The inner ``break`` only leaves the batch loop. Without this,
            # every remaining epoch would repeat the final checkpoint and
            # evaluation before breaking again.
            break
Example #7
0
def train():
    """Train the classification network with periodic checkpoints and evaluation.

    Reads the module-level ``args`` and ``cfg``. Weights come from
    ``args.resume``, from ``cfg.MODEL.PRETRAIN_MODEL`` (when ``args.pretrain``
    is set) or from ``weights_init``. Checkpoints store the whole ``cls_net``
    module under ``'state_dict'``, which is why loading calls
    ``.state_dict()`` on that entry. Training stops after the checkpoint and
    evaluation taken at ``cfg.TRAIN.MAX_ITER``; ``backup_jobs`` runs last.
    """
    tb_writer, cfg_path, snapshot_dir, log_dir = setup_folder(args, cfg)
    print(cfg_path)
    step_index = 0

    train_loader = dataset_factory(phase='train', cfg=cfg)
    val_loader = dataset_factory(phase='eval', cfg=cfg)
    eval_solver = eval_solver_factory(val_loader, cfg)

    cls_net = clsmodel_factory(phase='train', cfg=cfg)
    net = cls_net  # net is the parallel version of cls_net

    print(net)
    if args.shownet:
        # Only print the architecture.
        return

    if args.cuda:
        net = torch.nn.DataParallel(cls_net, device_ids=cfg.GENERAL.NET_CPUS)

    cls_net.apply(weights_init)
    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_iter = checkpoint['iteration']
        step_index = checkpoint['step_index']
        cls_net.load_state_dict(checkpoint['state_dict'].state_dict())
    elif args.pretrain:
        # pretrained weights
        print('Loading pretrained model: {}'.format(cfg.MODEL.PRETRAIN_MODEL))
        pretrain_weights = torch.load(cfg.MODEL.PRETRAIN_MODEL)
        if 'reducedfc' not in cfg.MODEL.PRETRAIN_MODEL:
            print('Loading whole network...')
            cls_net.load_state_dict(pretrain_weights['state_dict'].state_dict(), strict=False)
        else:
            print('Loading base network...')
            cls_net.base.load_state_dict(pretrain_weights['state_dict'].state_dict(), strict=False)
    else:
        cls_net.apply(weights_init)
        print('random init net weight with xavier')

    if args.cuda:
        net = net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=cfg.TRAIN.OPTIMIZER.LR,
                          momentum=cfg.TRAIN.OPTIMIZER.MOMENTUM,
                          weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY)

    criterion_post = FocalLoss_BCE(alpha=0.8, gamma=2, num_classes=1)

    # continue training at 8w, 12w...
    if args.start_iter not in cfg.TRAIN.LR_SCHEDULER.STEPS and step_index != 0:
        adjust_learning_rate(optimizer, cfg.TRAIN.OPTIMIZER.LR, cfg.TRAIN.LR_SCHEDULER.GAMMA, 100,
                             step_index, None, None)

    net.train()
    epoch_size = len(train_loader.dataset) // cfg.DATASET.TRAIN_BATCH_SIZE
    num_epochs = (cfg.TRAIN.MAX_ITER - args.start_iter + epoch_size - 1) // epoch_size
    print('Training SSD on:', train_loader.dataset.name)
    print('Using the specified args:')
    print(args)

    # timer
    t_ = {'network': Timer(), 'forward': Timer(), 'misc': Timer(), 'all': Timer(), 'eval': Timer()}
    t_['all'].tic()

    iteration = args.start_iter
    # Offset the epoch counter by the epochs already covered by start_iter.
    epoch_bias = int(iteration / epoch_size)
    reached_max_iter = False
    for epoch in range(epoch_bias, epoch_bias + num_epochs):
        tb_writer.cfg['epoch'] = epoch
        for images, targets, _ in train_loader:
            tb_writer.cfg['iteration'] = iteration
            t_['misc'].tic()
            if iteration in cfg.TRAIN.LR_SCHEDULER.STEPS or \
                    (epoch < cfg.TRAIN.WARMUP_EPOCH and not args.resume):
                if epoch >= cfg.TRAIN.WARMUP_EPOCH:
                    step_index += 1
                adjust_learning_rate(optimizer, cfg.TRAIN.OPTIMIZER.LR, cfg.TRAIN.LR_SCHEDULER.GAMMA, epoch,
                                     step_index, iteration, epoch_size, cfg.TRAIN.WARMUP_EPOCH)

            # save model
            if iteration % cfg.TRAIN.SAVE_ITER == 0 and iteration != args.start_iter or \
                    iteration == cfg.TRAIN.MAX_ITER:
                print('Saving state, iter:', iteration)
                save_checkpoint({'iteration': iteration,
                                 'step_index': step_index,
                                 'state_dict': cls_net},
                                snapshot_dir,
                                args.cfg_name + '_' + repr(iteration) + '.pth')
            # Eval
            if iteration % cfg.TRAIN.EVAL_ITER == 0 or iteration == cfg.TRAIN.MAX_ITER:
                tb_writer.cfg['phase'] = 'eval'
                tb_writer.cfg['iter'] = iteration
                t_['eval'].tic()
                net.eval()
                # Evaluate through a wrapper restricted to device 0, then
                # rebuild the training wrapper around cls_net afterwards.
                if torch.cuda.is_available():
                    net = nn.DataParallel(net, device_ids=[0])
                aps, mAPs = eval_solver.validate(net, tb_writer=tb_writer)
                net.train()
                if torch.cuda.is_available():
                    net = torch.nn.DataParallel(cls_net, device_ids=cfg.GENERAL.NET_CPUS)
                t_['eval'].toc()
                print('Iteration ' + str(iteration) + ' || mAP: %.3f' % mAPs[0] + ' ||eval_time: %.4f/%.4f' %
                      (t_['eval'].diff, t_['eval'].average_time))
                # NOTE(review): the scalar tags below look like the residue of
                # an e-mail obfuscation filter -- confirm the intended names.
                if cfg.DATASET.NAME == 'VOC0712':
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0], iteration)
                else:
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[0], iteration)
                    tb_writer.writer.add_scalar('mAP/[email protected]', mAPs[1], iteration)
                tb_writer.cfg['phase'] = 'train'

                if iteration == cfg.TRAIN.MAX_ITER:
                    reached_max_iter = True
                    break

            targets = torch.stack(targets)
            if args.cuda:
                images = Variable(images.cuda(), requires_grad=False)
                targets = Variable(targets.cuda(cfg.GENERAL.LOSS_GPU), requires_grad=False)
            else:
                images = Variable(images)
                targets = Variable(targets, requires_grad=False)

            # forward
            t_['network'].tic()
            t_['forward'].tic()

            net_outputs = net(images)
            t_['forward'].toc()

            loss = criterion_post(net_outputs, targets)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t_['network'].toc()

            # log
            if iteration % cfg.TRAIN.LOG_LOSS_ITER == 0 or iteration == 1:
                t_['misc'].toc()
                now_time = datetime.datetime.now()
                time_str = datetime.datetime.strftime(now_time, '%Y-%m-%d %H:%M:%S')
                print(time_str+'\tIter ' + str(iteration) + ' || Loss: %.3f' % (loss.data[0]), end=' ')
                print('Forward Timer: %.3f sec.' % t_['forward'].diff, '  Lr: %.6f' % optimizer.param_groups[0]['lr'])
                if args.tensorboard:
                    phase = tb_writer.cfg['phase']
                    tb_writer.writer.add_scalar('{}/all_loss'.format(phase), loss.data[0], iteration)
                    tb_writer.writer.add_scalar('{}/time'.format(phase), t_['misc'].diff, iteration)

            iteration += 1

        if reached_max_iter:
            # The inner ``break`` only leaves the batch loop. Without this,
            # every remaining epoch would repeat the final checkpoint and
            # evaluation before breaking again.
            break

    backup_jobs(cfg, cfg_path, log_dir)