def __init__(self, cfg, use_cuda):
        """Build the main detector and the optional auxiliary training heads.

        Each ``cfg.TRAIN`` flag checked below enables one sub-trainer, which
        is stored on the instance (``self.rot_trainer``,
        ``self.jig_33_trainer``, ``self.tsf_trainer``, ...).

        Args:
            cfg: configuration object; ``cfg.NUM_CLASSES`` and the
                ``cfg.TRAIN.*`` flags read below are used.
            use_cuda: stored unchanged as ``self.use_cuda``.

        Raises:
            ValueError: if any auxiliary head is enabled while
                ``cfg.TRAIN.MAIN`` is disabled. Every auxiliary head is sized
                from the main network's ``feat_out_channels`` (and the
                contrastive heads also wrap ``self.net``), so none of them
                can be built without the main network.
        """
        super(Trainer, self).__init__()

        self.cfg = cfg
        self.use_cuda = use_cuda
        train_cfg = cfg.TRAIN

        # Networks

        if train_cfg.MAIN:
            # In this variant build_net returns the network together with the
            # channel count of the features the auxiliary heads consume.
            self.net, feat_out_channels = build_net('train', cfg.NUM_CLASSES,
                                                    'vgg')
        else:
            aux_flags = ('ROTATION', 'CONTRASTIVE', 'JIGSAW_33', 'JIGSAW_22',
                         'JIGSAW_41', 'JIGSAW_14', 'CONTRASTIVE_SOURCE',
                         'CONTRASTIVE_TARGET', 'TRANSFER_CONTRASTIVE')
            requested = [flag for flag in aux_flags
                         if getattr(train_cfg, flag)]
            if requested:
                # Without the main network there is neither a self.net nor a
                # feat_out_channels for the heads below to attach to.
                raise ValueError(
                    'cfg.TRAIN.MAIN must be enabled to train auxiliary heads; '
                    'requested: {}'.format(', '.join(requested)))

        # Rotation-prediction head with 4 output classes.
        if train_cfg.ROTATION:
            self.rot_trainer = BasicTrainer(feat_out_channels, 4)

        if train_cfg.CONTRASTIVE:
            self.con_trainer = ContrastiveTrainer(self.net, feat_out_channels)

        # Jigsaw heads. The class counts (30 for the 3x3 grid, 24 for the
        # 2x2 / 4x1 / 1x4 grids) are presumably the number of tile
        # permutations used by the data pipeline -- TODO confirm.
        if train_cfg.JIGSAW_33:
            self.jig_33_trainer = BasicTrainer(feat_out_channels, 30)

        if train_cfg.JIGSAW_22:
            self.jig_22_trainer = BasicTrainer(feat_out_channels, 24)

        if train_cfg.JIGSAW_41:
            self.jig_41_trainer = BasicTrainer(feat_out_channels, 24)

        if train_cfg.JIGSAW_14:
            self.jig_14_trainer = BasicTrainer(feat_out_channels, 24)

        # Contrastive heads; each wraps the shared main network.
        if train_cfg.CONTRASTIVE_SOURCE:
            self.consrc_trainer = ContrastiveTrainer(self.net,
                                                     feat_out_channels)

        if train_cfg.CONTRASTIVE_TARGET:
            self.contrg_trainer = ContrastiveTrainer(self.net,
                                                     feat_out_channels)

        if train_cfg.TRANSFER_CONTRASTIVE:
            self.tsf_trainer = ContrastiveTrainer(self.net, feat_out_channels)
Beispiel #2
0
def train():
    """Train the DSFD network, validating at the end of every epoch.

    Takes no arguments. It reads module-level state defined elsewhere in the
    file: ``args``, ``cfg``, ``train_dataset``, ``train_loader``,
    ``save_folder`` and ``val``.

    Training stops after ``cfg.EPOCHES`` epochs or once ``cfg.MAX_STEPS``
    iterations have run, whichever comes first. Every 5000 iterations, the
    bare network's ``state_dict`` is saved to
    ``save_folder/dsfd_<iteration>.pth``.
    """
    per_epoch_size = len(train_dataset) // args.batch_size
    start_epoch = 0
    iteration = 0
    step_index = 0

    basenet = basenet_factory(args.model)
    dsfd_net = build_net('train', cfg.NUM_CLASSES, args.model)
    # ``net`` may be replaced by a DataParallel / CUDA module below, while
    # ``dsfd_net`` always refers to the bare model. Its sub-modules and
    # state_dict are what get initialised and saved.
    net = dsfd_net

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        start_epoch = net.load_weights(args.resume)
        iteration = start_epoch * per_epoch_size
    else:
        base_path = os.path.join(args.save_folder, basenet)
        base_weights = torch.load(base_path)
        print('Load base network {}'.format(base_path))

        if args.model == 'vgg':
            # Strip the ``module.`` prefix that DataParallel adds to keys.
            net.load_state_dict(
                {k.replace('module.', ''): v
                 for k, v in base_weights.items()})
        else:
            net.resnet.load_state_dict(base_weights)

    if args.cuda:
        if args.multigpu:
            net = torch.nn.DataParallel(dsfd_net)
        net = net.cuda()
        # The original spelling ``benckmark`` set an unused attribute, so
        # cuDNN autotuning was never actually enabled.
        cudnn.benchmark = True

    if not args.resume:
        print('Initializing weights...')
        # NOTE(review): if base_weights above already contained these
        # modules (the full-network load_state_dict suggests it may), they
        # are overwritten here -- confirm that is intended.
        for name in ('extras', 'fpn_topdown', 'fpn_latlayer', 'fpn_fem',
                     'loc_pal1', 'conf_pal1', 'loc_pal2', 'conf_pal2'):
            getattr(dsfd_net, name).apply(dsfd_net.weights_init)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg, args.cuda)
    print('Loading wider dataset...')
    print('Using the specified args:')
    print(args)

    # When resuming, replay the LR decays for milestones already passed.
    for step in cfg.LR_STEPS:
        if iteration > step:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

    net.train()
    for epoch in range(start_epoch, cfg.EPOCHES):
        losses = 0.0
        for batch_idx, (images, targets) in enumerate(train_loader):
            # Plain tensors suffice: Variable is a no-op since PyTorch 0.4,
            # and ``volatile`` only emits a deprecation warning.
            if args.cuda:
                images = images.cuda()
                targets = [ann.cuda() for ann in targets]

            if iteration in cfg.LR_STEPS:
                step_index += 1
                adjust_learning_rate(optimizer, args.gamma, step_index)

            t0 = time.time()
            out = net(images)
            # backprop
            optimizer.zero_grad()
            # out[:3] and out[3:] are scored separately as the two shots
            # (pal1 / pal2); presumably (loc, conf, priors) each -- confirm
            # against the model's forward().
            loss_l_pal1, loss_c_pal1 = criterion(out[:3], targets)
            loss_l_pal2, loss_c_pal2 = criterion(out[3:], targets)

            loss = loss_l_pal1 + loss_c_pal1 + loss_l_pal2 + loss_c_pal2
            loss.backward()
            optimizer.step()
            t1 = time.time()
            # Accumulate as a Python float rather than a device tensor.
            losses += loss.item()

            if iteration % 10 == 0:
                tloss = losses / (batch_idx + 1)
                print('Timer: %.4f' % (t1 - t0))
                print('epoch:' + repr(epoch) + ' || iter:' + repr(iteration) +
                      ' || Loss:%.4f' % (tloss))
                print(
                    '->> pal1 conf loss:{:.4f} || pal1 loc loss:{:.4f}'.format(
                        loss_c_pal1.item(), loss_l_pal1.item()))
                print(
                    '->> pal2 conf loss:{:.4f} || pal2 loc loss:{:.4f}'.format(
                        loss_c_pal2.item(), loss_l_pal2.item()))
                print('->>lr:{}'.format(optimizer.param_groups[0]['lr']))

            if iteration != 0 and iteration % 5000 == 0:
                print('Saving state, iter:', iteration)
                filename = 'dsfd_' + repr(iteration) + '.pth'
                torch.save(dsfd_net.state_dict(),
                           os.path.join(save_folder, filename))
            iteration += 1
            # Honour the global step budget mid-epoch; the original ``==``
            # check ran only at epoch end and so almost never fired.
            if iteration >= cfg.MAX_STEPS:
                break

        val(epoch, net, dsfd_net, criterion)
        if iteration >= cfg.MAX_STEPS:
            break
Beispiel #3
0
def main(args):
    """Set up data, model and optimiser, then train with per-epoch validation.

    Args:
        args: parsed command-line options. Reads ``multigpu``, ``cuda``,
            ``save_folder``, ``model``, ``batch_size``, ``num_workers``,
            ``resume``, ``lr``, ``momentum``, ``weight_decay`` and ``gamma``.

    Side effects:
        Creates ``<args.save_folder>/<args.model>`` if it is missing. Writes
        ``dsfd.pth`` there whenever the validation loss improves, and
        ``dsfd_checkpoint.pth`` (epoch number plus weights) after every epoch.
    """
    # check for multiple gpus
    # NOTE(review): CUDA_LAUNCH_BLOCKING only takes effect if set before CUDA
    # is initialised -- confirm nothing touches CUDA before this point.
    if not args.multigpu:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # choose the default tensor type according to CUDA availability / flag
    if torch.cuda.is_available() and args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
        torch.set_default_tensor_type('torch.FloatTensor')

    # create saving directory for checkpoints; makedirs also creates missing
    # parents (os.mkdir would fail if args.save_folder does not exist yet)
    save_folder = os.path.join(args.save_folder, args.model)
    os.makedirs(save_folder, exist_ok=True)

    # define the datasets and data loaders
    train_dataset = WIDERDetection(cfg.FACE.TRAIN_FILE, mode='train')
    val_dataset = WIDERDetection(cfg.FACE.VAL_FILE, mode='val')

    train_loader = data.DataLoader(train_dataset,
                                   args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   collate_fn=detection_collate,
                                   pin_memory=True)
    # Half the training batch size, but never 0 (DataLoader rejects 0).
    val_batchsize = max(1, args.batch_size // 2)
    val_loader = data.DataLoader(val_dataset,
                                 val_batchsize,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 collate_fn=detection_collate,
                                 pin_memory=True)

    min_loss = np.inf

    per_epoch_size = len(train_dataset) // args.batch_size
    start_epoch = 0
    iteration = 0
    step_index = 0

    # define the model; ``dsfd_net`` stays the bare model, ``net`` may be
    # wrapped for multi-GPU / moved to CUDA below
    basenet = basenet_factory(args.model)
    dsfd_net = build_net('train', cfg.NUM_CLASSES, args.model)
    net = dsfd_net

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        start_epoch = net.load_weights(args.resume)
        iteration = start_epoch * per_epoch_size
    else:
        # Fresh run: initialise the backbone from the base-network weights.
        # This previously ran only when resuming, which overwrote the
        # restored backbone and left fresh runs without base weights.
        base_path = os.path.join(args.save_folder, basenet)
        base_weights = torch.load(base_path)
        print('Load base network {}'.format(base_path))
        if args.model == 'vgg':
            net.vgg.load_state_dict(base_weights)
        else:
            net.resnet.load_state_dict(base_weights)

    # if cuda available and if multiple gpus available
    if args.cuda:
        if args.multigpu:
            net = torch.nn.DataParallel(dsfd_net)
        net = net.cuda()
        # Was misspelled ``benckmark``, which silently did nothing.
        cudnn.benchmark = True

    # randomly initialize the non-backbone modules
    if not args.resume:
        print('Randomly initializing weights for the described DSFD Model...')
        for name in ('extras', 'fpn_topdown', 'fpn_latlayer', 'fpn_fem',
                     'loc_pal1', 'conf_pal1', 'loc_pal2', 'conf_pal2'):
            getattr(dsfd_net, name).apply(dsfd_net.weights_init)

    # define the optimizer and loss criteria
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg, args.cuda)
    print('Loading wider dataset...')
    print('Using the specified args:')
    print(args)

    # when resuming, replay the LR decays for milestones already passed
    for step in cfg.LR_STEPS:
        if iteration > step:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

    for epoch in range(start_epoch, cfg.EPOCHES):
        train_loss, iteration, step_index = train(net,
                                                  criterion,
                                                  train_loader,
                                                  optimizer,
                                                  epoch,
                                                  iteration,
                                                  step_index,
                                                  args.gamma,
                                                  device=None)

        val_loss = validate(net, criterion, val_loader, epoch)

        # validation loss less than the previous best: save the better checkpoint
        if val_loss < min_loss:
            print('Saving best state,epoch', epoch)
            torch.save(dsfd_net.state_dict(),
                       os.path.join(save_folder, 'dsfd.pth'))
            min_loss = val_loss

        states = {
            'epoch': epoch,
            'weight': dsfd_net.state_dict(),
        }
        torch.save(states, os.path.join(save_folder, 'dsfd_checkpoint.pth'))
        # ``>=``: iteration advances by a whole epoch at a time, so an exact
        # equality check against the step budget would rarely trigger.
        if iteration >= cfg.MAX_STEPS:
            break