# Example #1 (예제 #1)
# 0
def train():
    """Train PyramidBox: first of two train() variants in this file.

    Reads module-level globals (net, ssd_net, optimizer, criterion, args,
    viz, stepvalues, gamma, max_iter, ssd_dim, means, batch_size).  Saves a
    checkpoint every 5000 iterations and once after the loop finishes.
    Written against the legacy (pre-0.4) PyTorch Variable/.data[0] API.
    Batches that raise TypeError or ValueError are skipped instead of
    aborting the run.
    """
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0
    print('Loading Dataset...')

    dataset = Detection(args.annoPath, PyramidAugmentation(ssd_dim, means), AnnotationTransform())
    print('len(dataset) = ' + str(len(dataset)))
    print(dataset.__getitem__(0))
    # optimizer steps per pass over the dataset (last partial batch dropped)
    epoch_size = len(dataset) // args.batch_size
    print('Training PyramidBox on', dataset.name)
    step_index = 0
    if args.visdom:
        # initialize visdom loss plot
        # per-iteration loss window: loc / conf / total
        lot = viz.line(
            X=np.array(torch.zeros((1,)).cpu()),
            Y=np.array(torch.zeros((1, 3)).cpu()),
            opts=dict(
                xlabel='Iteration',
                ylabel='Loss',
                title='Current PyramidBox Training Loss',
                legend=['Loc Loss', 'Conf Loss', 'Loss']
            )
        )
        # per-epoch averaged loss window
        epoch_lot = viz.line(
            X=np.array(torch.zeros((1,)).cpu()),
            Y=np.array(torch.zeros((1, 3)).cpu()),
            opts=dict(
                xlabel='Epoch',
                ylabel='Loss',
                title='Epoch PyramidBox Training Loss',
                legend=['Loc Loss', 'Conf Loss', 'Loss']
            )
        )
        # learning-rate window (title suggests it tracks the warm-up schedule)
        lr_lot = viz.line(
            X=np.array(torch.zeros((1,)).cpu()),
            Y=np.array(torch.zeros((1,1)).cpu()),
            opts=dict(
                xlabel='iteration',
                ylabel='learning-rate',
                title='Warm-up',
                legend=['lr']
            )
        )
    batch_iterator = None
    data_loader = data.DataLoader(dataset, batch_size, num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate, pin_memory=True)
    print('data loading finished...')
    for iteration in range(args.start_iter, max_iter):
        t0 = time.time()
        try:
            # (re)create the batch iterator at start-up and at each epoch boundary
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(data_loader)
            # NOTE(review): this variant passes `iteration` while the second
            # train() below passes `step_index` — confirm which argument
            # adjust_learning_rate() actually expects.
            adjust_learning_rate(optimizer, gamma, iteration)
            if iteration in stepvalues:
                step_index += 1
                if args.visdom:
                    # log the epoch-averaged losses before resetting the counters
                    viz.line(
                        X=np.array(torch.ones((1, 3)).cpu()) * epoch,
                        Y=np.array(torch.Tensor([loc_loss, conf_loss,
                            loc_loss + conf_loss]).unsqueeze(0).cpu() )/ epoch_size,
                        win=epoch_lot,
                        update='append'
                    )
                # reset epoch loss counters
                loc_loss = 0
                conf_loss = 0
                epoch += 1
            # load train data
            images, targets = next(batch_iterator)

            if args.cuda:
                images = Variable(images.cuda())
                # volatile=True: legacy (pre-0.4) way to disable autograd on targets
                targets = [Variable(anno.cuda(), volatile=True) for anno in targets]
            else:
                images = Variable(images)
                targets = [Variable(anno, volatile=True) for anno in targets]
            # forward
            t1 = time.time()
            out = net(images)
            # backprop
            optimizer.zero_grad()
            # out[0:3] / out[3:6]: presumably the face and head prediction
            # branches of PyramidBox — TODO confirm against the model definition
            loss_l, loss_c = criterion(tuple(out[0:3]), targets)
            loss_l_head, loss_c_head = criterion(tuple(out[3:6]), targets)
            # head-branch losses are down-weighted by 0.5 in the total objective
            loss = loss_l + loss_c + 0.5 * loss_l_head + 0.5 * loss_c_head
            loss.backward()
            optimizer.step()
            t2 = time.time()
            loc_loss += loss_l.data[0]
            conf_loss += loss_c.data[0]
            if iteration % 10 == 0:
                print('front and back Timer: {} sec.' .format((t2 - t1)))
                print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]))
                print('Loss conf: {} Loss loc: {}'.format(loss_c.data[0],loss_l.data[0]))
                print('Loss head conf: {} Loss head loc: {}'.format(loss_c_head.data[0],loss_l_head.data[0]))
                print('lr: {}'.format(optimizer.param_groups[0]['lr']))
                if args.visdom and args.send_images_to_visdom:
                    random_batch_index = np.random.randint(images.size(0))
                    viz.image(images.data[random_batch_index].cpu().numpy())
            if args.visdom:
                viz.line(
                    X=np.array(torch.ones((1, 3)).cpu()) * iteration,
                    Y=np.array(torch.Tensor([loss_l.data[0], loss_c.data[0],
                        loss_l.data[0] + loss_c.data[0]]).unsqueeze(0).cpu()),
                    win=lot,
                    update='append'
                )
                viz.line(
                    X=np.array(torch.ones((1,1)).cpu()) * iteration,
                    Y=np.array(torch.Tensor([optimizer.param_groups[0]['lr']]).unsqueeze(0).cpu()),
                    win=lr_lot,
                    update='append'
                )
                # hacky fencepost solution for 0th epoch plot
                if iteration == 0:
                    viz.line(
                        X=np.array(torch.zeros((1, 3)).cpu()),
                        Y=np.array(torch.Tensor([loc_loss, conf_loss,
                            loc_loss + conf_loss]).unsqueeze(0).cpu()),
                        win=epoch_lot,
                        update=True
                    )
                    viz.line(
                        X=np.array(torch.zeros((1,1)).cpu()),
                        Y=np.array(torch.Tensor([optimizer.param_groups[0]['lr']]).unsqueeze(0).cpu()),
                        win=lr_lot,
                        update=True
                    )

        # deliberately skip malformed batches and keep training
        except TypeError as e:
            print(e)
            print('-'*20,'jump to next iter and log.')
            continue
        except ValueError as e2:
            print(e2)
            print('='*20,'jump to next iter and log.')
            continue
        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(ssd_net.state_dict(), args.save_folder + 'Res50_pyramid_' +
                       repr(iteration) + '.pth')
    torch.save(ssd_net.state_dict(), args.save_folder + 'Res50_pyramid_' + '.pth')
def train():
    """Train PyramidBox: second variant (redefines/shadows the train() above).

    Differences from the first variant visible in this file: warm-up via the
    first five entries of `stepvalues`, best-loss checkpointing, checkpoints
    every 500 iterations (and at each stepvalue), and no try/except batch
    skipping.  Reads the same module-level globals (net, ssd_net, optimizer,
    criterion, args, viz, stepvalues, gamma, max_iter, ssd_dim, means,
    batch_size).  Uses the legacy (pre-0.4) PyTorch Variable/.data[0] API.
    """
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0
    # lowest total loss seen so far; used to checkpoint the best model
    min_loss = float('inf')
    print('Loading Dataset...')

    dataset = Detection(args.annoPath, PyramidAugmentation(ssd_dim, means), AnnotationTransform())

    # optimizer steps per pass over the dataset (last partial batch dropped)
    epoch_size = len(dataset) // args.batch_size
    print('Training SSD on', dataset.name)
    step_index = 0
    step_increase = 0
    if args.visdom:
        # initialize visdom loss plot
        # per-iteration loss window: loc / conf / total
        lot = viz.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1, 3)).cpu(),
            opts=dict(
                xlabel='Iteration',
                ylabel='Loss',
                title='Current SSD Training Loss',
                legend=['Loc Loss', 'Conf Loss', 'Loss']
            )
        )
        # per-epoch averaged loss window
        epoch_lot = viz.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1, 3)).cpu(),
            opts=dict(
                xlabel='Epoch',
                ylabel='Loss',
                title='Epoch SSD Training Loss',
                legend=['Loc Loss', 'Conf Loss', 'Loss']
            )
        )
    batch_iterator = None
    data_loader = data.DataLoader(dataset, batch_size, num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate, pin_memory=True)
    for iteration in range(args.start_iter, max_iter):
        t0 = time.time()
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(data_loader)
        if iteration in stepvalues:
            # the first five stepvalues ramp the lr up (warm-up);
            # later stepvalues decay it
            if iteration in stepvalues[0:5]:
                step_increase += 1
                warmup_learning_rate(optimizer, args.lr, step_increase)
            else:
                step_index += 1
                adjust_learning_rate(optimizer, gamma, step_index)
            if args.visdom:
                # log the epoch-averaged losses before resetting the counters
                viz.line(
                    X=torch.ones((1, 3)).cpu() * epoch,
                    Y=torch.Tensor([loc_loss, conf_loss,
                                    loc_loss + conf_loss]).unsqueeze(0).cpu() / epoch_size,
                    win=epoch_lot,
                    update='append'
                )
            # reset epoch loss counters
            loc_loss = 0
            conf_loss = 0
            epoch += 1

        # load train data
        images, targets = next(batch_iterator)

        if args.cuda:
            images = Variable(images.cuda())
            # volatile=True: legacy (pre-0.4) way to disable autograd on targets
            targets = [Variable(anno.cuda(), volatile=True) for anno in targets]
        else:
            images = Variable(images)
            targets = [Variable(anno, volatile=True) for anno in targets]
        # forward
        t1 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        # out[0:3] / out[3:6]: presumably the face and head prediction
        # branches of PyramidBox — TODO confirm against the model definition
        loss_l, loss_c = criterion(tuple(out[0:3]), targets)
        loss_l_head, loss_c_head = criterion(tuple(out[3:6]), targets)

        # head-branch losses are down-weighted by 0.5 in the total objective
        loss = loss_l + loss_c + 0.5 * loss_l_head + 0.5 * loss_c_head
        # checkpoint whenever a new best (lowest) total loss is reached
        if (loss.data[0] < min_loss):
            min_loss = loss.data[0]
            print("min_loss: " , min_loss)
            torch.save(ssd_net.state_dict(), args.save_folder + 'best_our_ucsd_Res50_pyramid_aug' + '.pth')
        loss.backward()
        optimizer.step()
        t2 = time.time()
        loc_loss += loss_l.data[0]
        conf_loss += loss_c.data[0]
        if iteration % 50 == 0:
            print('front and back Timer: {} sec.'.format((t2 - t1)))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]))
            print('Loss conf: {} Loss loc: {}'.format(loss_c.data[0], loss_l.data[0]))
            print('Loss head conf: {} Loss head loc: {}'.format(loss_c_head.data[0], loss_l_head.data[0]))
            print('lr: {}'.format(optimizer.param_groups[0]['lr']))
            if args.visdom and args.send_images_to_visdom:
                random_batch_index = np.random.randint(images.size(0))
                viz.image(images.data[random_batch_index].cpu().numpy())
        if args.visdom:
            viz.line(
                X=torch.ones((1, 3)).cpu() * iteration,
                Y=torch.Tensor([loss_l.data[0], loss_c.data[0],
                                loss_l.data[0] + loss_c.data[0]]).unsqueeze(0).cpu(),
                win=lot,
                update='append'
            )
            # hacky fencepost solution for 0th epoch plot
            if iteration == 0:
                viz.line(
                    X=torch.zeros((1, 3)).cpu(),
                    Y=torch.Tensor([loc_loss, conf_loss,
                                    loc_loss + conf_loss]).unsqueeze(0).cpu(),
                    win=epoch_lot,
                    update=True
                )
        # periodic checkpoint, plus one at every lr-schedule boundary
        if iteration % 500 == 0 or iteration in stepvalues:
            print('Saving state, iter:', iteration)
            torch.save(ssd_net.state_dict(), args.save_folder + 'our_ucsd_Res50_pyramid_aug_' +  repr(iteration) + '.pth')
    torch.save(ssd_net.state_dict(), args.save_folder + 'our_ucsd_Res50_pyramid_aug' + '.pth')
# Example #3 (예제 #3)
# 0
    print('Resuming training, loading {}...'.format(args.resume))
    ssd_net.load_weights(args.resume)
else:
    pass

# Module-level training setup: optimizer, mixed precision, losses, dataset,
# and optional distributed data parallelism.
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=momentum,
                      weight_decay=weight_decay)
# BUG FIX: Apex amp opt_level must be the letter "O" followed by a digit
# ("O0"–"O3"); the original passed "02" (zero-two), which amp.initialize
# rejects at startup.
ssd_net, optimizer = amp.initialize(ssd_net, optimizer, opt_level="O2")

# Two MultiBox losses differing only in the 9th positional flag (False/True).
# NOTE(review): criterion1 is not used by either train() visible in this
# file — confirm whether it is dead code or used elsewhere.
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 3, 0.35, False,
                         False, args.cuda)
criterion1 = MultiBoxLoss(num_classes, 0.35, True, 0, True, 3, 0.35, False,
                          True, args.cuda)
dataset = Detection(args.annoPath, PyramidAugmentation(ssd_dim, means),
                    AnnotationTransform())
if args.useMultiProcess:
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method="env://",
                                         world_size=args.world_size,
                                         rank=args.device_ids)
    net = torch.nn.parallel.DistributedDataParallel(
        ssd_net, device_ids=tuple(np.arange(0, args.device_ids)))
    # boost computing rate but some randomness vice versa cudnn.deterministic=True
    cudnn.benchmark = True
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)


def train():
    net.train()