Example #1
def main():
    opt = opts().parse()
    now = datetime.datetime.now()

    # logger is used below; construct it as in Example #4
    logger = Logger(opt.saveDir, now.isoformat())

    if opt.loadModel != 'none':
        model = torch.load(opt.loadModel).cuda()
    else:
        model = AlexNet(ref.nJoints).cuda()

    criterion = torch.nn.MSELoss().cuda()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    opt.LR,
                                    alpha=ref.alpha,
                                    eps=ref.epsilon,
                                    weight_decay=ref.weightDecay,
                                    momentum=ref.momentum)

    val_loader = torch.utils.data.DataLoader(H36M(opt, 'val'),
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=int(ref.nThreads))

    #    if opt.test:
    #        val(0, opt, val_loader, model, criterion)
    #        return

    train_loader = torch.utils.data.DataLoader(
        Fusion(opt, 'train'),
        batch_size=opt.trainBatch,
        shuffle=True,  #if opt.DEBUG == 0 else False,
        num_workers=int(ref.nThreads))

    for epoch in range(1, opt.nEpochs + 1):
        loss_train, acc_train, mpjpe_train, loss3d_train = train(
            epoch, opt, train_loader, model, criterion, optimizer)

        logger.scalar_summary('loss_train', loss_train, epoch)
        #logger.scalar_summary('acc_train', acc_train, epoch)
        #logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
        #logger.scalar_summary('loss3d_train', loss3d_train, epoch)

        if epoch % opt.valIntervals == 0:
            loss_val, acc_val, mpjpe_val, loss3d_val = val(
                epoch, opt, val_loader, model, criterion)
            logger.scalar_summary('loss_val', loss_val, epoch)
            #logger.scalar_summary('acc_val', acc_val, epoch)
            #logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
            #logger.scalar_summary('loss3d_val', loss3d_val, epoch)
            torch.save(model,
                       os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
            #logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train, loss_val, acc_val, mpjpe_val, loss3d_val))
            logger.write('{:8f} {:8f} \n'.format(loss_train, loss_val))

        else:
            #logger.write('{:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train))
            logger.write('{:8f} \n'.format(loss_train))

        adjust_learning_rate(optimizer, epoch, opt.dropLR, opt.LR)
    logger.close()
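The helper `adjust_learning_rate(optimizer, epoch, opt.dropLR, opt.LR)` called above is not shown in the snippet. A minimal sketch, assuming a step schedule that divides the base rate by 10 every `dropLR` epochs (the same formula Example #4 computes inline), might look like this:

def adjust_learning_rate(optimizer, epoch, dropLR, LR):
    # Hypothetical sketch: step decay, dividing the base rate by 10
    # every `dropLR` epochs; the real helper is not part of this example.
    lr = LR * (0.1 ** (epoch // dropLR))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr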
Example #2
def train(epoch):
    global iter
    max_iter = config['num_epochs'] * len(trainloader)
    train_loss_1.reset()
    train_acc.reset()
    train_IOU.reset()
    train_back_IOU.reset()
    net.train()
    for idx, batch in enumerate(trainloader):
        end = time.time()
        new_lr = utils.polynomial_decay(optim_opt['lr'],
                                        iter,
                                        max_iter,
                                        power=0.9,
                                        end_learning_rate=1e-4)
        utils.adjust_learning_rate(optimizer, new_lr)
        image = batch[0].cuda()
        instance_label = batch[1].cuda()
        optimizer.zero_grad()
        prob_output = net(image)

        loss1 = F.cross_entropy(prob_output, instance_label, weight=weight)

        total_loss = loss1
        total_loss.backward()
        optimizer.step()
        ####################################
        train_loss_1.update(loss1.item())

        acc, IOU, back_IOU = utils.compute_accuracy(prob_output,
                                                    instance_label)
        train_acc.update(acc)
        train_IOU.update(IOU)
        train_back_IOU.update(back_IOU)

        if idx % config['display_step'] == 0:
            logger.info(
                '==> Iteration [{}][{}/{}][{}/{}]: loss1: {:.4f} ({:.4f})  lr:{:.4f} acc: {:.4f} ({:.4f}) IOU: {:.4f} ({:.4f}) back_IOU: {:.4f} ({:.4f}) time: {:.4f}'
                .format(
                    epoch + 1,
                    idx,
                    len(trainloader),
                    iter,
                    max_iter,
                    loss1.item(),
                    train_loss_1.avg,
                    new_lr,
                    acc,
                    train_acc.avg,
                    IOU,
                    train_IOU.avg,
                    back_IOU,
                    train_back_IOU.avg,
                    time.time() - end,
                ))

        iter += 1
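`utils.polynomial_decay` is also external to this snippet. A sketch consistent with the call above (decaying the base rate towards `end_learning_rate` over `max_iter` steps with power 0.9), offered as an assumption rather than the actual implementation:

def polynomial_decay(base_lr, step, max_steps, power=0.9, end_learning_rate=1e-4):
    # Hypothetical sketch: polynomial decay from base_lr to end_learning_rate.
    step = min(step, max_steps)
    decay = (1 - step / max_steps) ** power
    return (base_lr - end_learning_rate) * decay + end_learning_rate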
Example #3
def main(args):

    # load dataset
    train_set = LineModDataset(args.data_path, args.class_type)
    test_set = LineModDataset(args.data_path,
                              args.class_type,
                              is_train=False,
                              occ=args.occ)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=12)
    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             num_workers=12)

    device = torch.device(
        "cuda:0" if cuda else "cpu")  # primary device; DataParallel below spans GPUs 0-3
    dnn_model_dir = osp.join("model", args.class_type)
    mesh_model_dir = osp.join(args.data_path, "linemod", args.class_type,
                              "{}_new.ply".format(args.class_type))

    psgmnet = psgmn(mesh_model_dir)

    psgmnet = torch.nn.DataParallel(psgmnet, device_ids=[0, 1, 2, 3])
    psgmnet = psgmnet.to(device)
    optimizer = torch.optim.Adam(psgmnet.parameters(), lr=args.lr)

    # code for evaluation
    if args.eval:
        linemod_eval = evaluator(args, psgmnet, test_loader, device)
        load_network(psgmnet, dnn_model_dir, epoch=args.used_epoch)
        linemod_eval.evaluate()
        return

    if args.train:

        #start_epoch= 1
        start_epoch = load_network(psgmnet, dnn_model_dir) + 1
        for epoch in range(start_epoch, args.epochs + 1):
            print("current class:{}".format(args.class_type))
            adjust_learning_rate(optimizer, epoch, args.lr)

            loss = train(psgmnet, train_loader, optimizer, device)
            print(f'Epoch: {epoch:02d}, Loss: {loss*args.batch_size:.4f}')
            if epoch % 10 == 0:

                if not osp.exists(
                        osp.join(os.getcwd(), 'model', args.class_type)):
                    os.makedirs(osp.join(os.getcwd(), 'model',
                                         args.class_type))
                torch.save(
                    psgmnet.state_dict(),
                    osp.join('model', args.class_type, '{}.pkl'.format(epoch)))
Example #4
def main():
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir, now.isoformat())
    model, optimizer = getModel(opt)
    criterion = torch.nn.MSELoss().cuda()

    # if opt.GPU > -1:
    #     print('Using GPU {}',format(opt.GPU))
    #     model = model.cuda(opt.GPU)
    #     criterion = criterion.cuda(opt.GPU)
    # dev = opt.device
    model = model.cuda()

    val_loader = torch.utils.data.DataLoader(
            MPII(opt, 'val'), 
            batch_size = 1, 
            shuffle = False,
            num_workers = int(ref.nThreads)
    )

    if opt.test:
        log_dict_train, preds = val(0, opt, val_loader, model, criterion)
        sio.savemat(os.path.join(opt.saveDir, 'preds.mat'), mdict = {'preds': preds})
        return
    # pretrain pyramidnet once; first define the training data loader for gen
    train_loader = torch.utils.data.DataLoader(
            MPII(opt, 'train'), 
            batch_size = opt.trainBatch, 
            shuffle = True if opt.DEBUG == 0 else False,
            num_workers = int(ref.nThreads)
    )
    # call the train method
    for epoch in range(1, opt.nEpochs + 1):
        log_dict_train, _ = train(epoch, opt, train_loader, model, criterion, optimizer)
        for k, v in log_dict_train.items():
            logger.scalar_summary('train_{}'.format(k), v, epoch)
            logger.write('{} {:8f} | '.format(k, v))
        if epoch % opt.valIntervals == 0:
            log_dict_val, preds = val(epoch, opt, val_loader, model, criterion)
            for k, v in log_dict_val.items():
                logger.scalar_summary('val_{}'.format(k), v, epoch)
                logger.write('{} {:8f} | '.format(k, v))
            #saveModel(model, optimizer, os.path.join(opt.saveDir, 'model_{}.checkpoint'.format(epoch)))
            torch.save(model, os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
            sio.savemat(os.path.join(opt.saveDir, 'preds_{}.mat'.format(epoch)), mdict = {'preds': preds})
        logger.write('\n')
        if epoch % opt.dropLR == 0:
            lr = opt.LR * (0.1 ** (epoch // opt.dropLR))
            print('Drop LR to {}'.format(lr))
            adjust_learning_rate(optimizer, lr)
    logger.close()
    torch.save(model.cpu(), os.path.join(opt.saveDir, 'model_cpu.pth'))
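Example #4 computes the decayed rate itself (`lr = opt.LR * (0.1 ** (epoch // opt.dropLR))`) and hands it to `adjust_learning_rate(optimizer, lr)`. A minimal sketch of a helper with that two-argument signature, assuming it simply overwrites the rate on every parameter group:

def adjust_learning_rate(optimizer, lr):
    # Hypothetical sketch: set the given learning rate on all parameter groups.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr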
Example #5
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):

    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss1 = AverageMeter()
    ave_aux_loss = AverageMeter()
    ave_error_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        losses, aux_loss, error_loss, _ = model(images, labels)
        # print('pred', pred[2].size())
        loss = losses.mean() + 0.4 * aux_loss.mean() + 1 * error_loss.mean()

        reduced_loss = reduce_tensor(loss)
        loss1 = reduce_tensor(losses)
        aux_loss = reduce_tensor(aux_loss)
        error_losses = reduce_tensor(error_loss)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_loss1.update(loss1.item())
        ave_aux_loss.update(aux_loss.item())
        ave_error_loss.update(error_losses.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss1 = ave_loss1.average() / world_size
            print_loss_aux = ave_aux_loss.average() / world_size
            print_error_loss = ave_error_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, Loss_1: {:.6f}, Loss_aux: {:.6f}, error_loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss, print_loss1, print_loss_aux, print_error_loss)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
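Several of these examples (5, 8, 10, 14, 24) rely on an `AverageMeter` utility that is not included. A minimal sketch covering the `reset()` / `update()` / `avg` / `average()` usage seen in them, written here as an assumption:

class AverageMeter:
    # Hypothetical sketch of the running-average tracker used in these examples.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def average(self):
        return self.avg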
Example #6
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        print("Epoch " + str(epoch) + ':') 
        tbar = tqdm(self.train_loader)

        for i, (input, heatmap, centermap, img_path, limbsmap) in enumerate(tbar):
            learning_rate = adjust_learning_rate(self.optimizer, self.iters, self.lr, policy='step',
                                                 gamma=self.gamma, step_size=self.step_size)

            input_var = input.cuda()
            heatmap_var = heatmap.cuda()
            limbs_var = limbsmap.cuda()

            self.optimizer.zero_grad()

            heat, limbs = self.model(input_var)

            loss_heat = self.criterion(heat, heatmap_var)

            loss = loss_heat

            train_loss += loss_heat.item()

            loss.backward()
            self.optimizer.step()

            tbar.set_description('Train loss: %.6f' % (train_loss / ((i + 1)*self.batch_size)))

            self.iters += 1

            if i == 10000:
                break
Example #7
def train_epoch(model, optimizer, start_epoch=0):
    for epoch in range(start_epoch, hp.max_epoch):
        start_time = time.time()
        train_loop(model, optimizer, epoch)
        if (epoch + 1) % hp.save_per_epoch == 0 or (
                epoch + 1) % hp.reset_optimizer_epoch > 30:
            torch.save(model.state_dict(),
                       hp.save_dir + "/network.epoch{}".format(epoch + 1))
            torch.save(
                optimizer.state_dict(),
                hp.save_dir + "/network.optimizer.epoch{}".format(epoch + 1))
        adjust_learning_rate(optimizer, epoch + 1)
        if (epoch + 1) % hp.reset_optimizer_epoch == 0:
            optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
        print("EPOCH {} end".format(epoch + 1))
        print(f'elapsed time = {(time.time()-start_time)//60}m')
Example #8
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict):
    # Training
    model.train()
    scaler = GradScaler()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_acc = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        images = images.cuda()
        # print("images:",images.size())
        labels = labels.long().cuda()
        # print("label:",labels.size())
        with autocast():
            losses, _, acc = model(images, labels)
        loss = losses.mean()
        acc = acc.mean()

        if dist.is_distributed():
            reduced_loss = reduce_tensor(loss)
        else:
            reduced_loss = loss

        model.zero_grad()
        scaler.scale(loss).backward()
        #  loss.backward()

        #optimizer.step()
        scaler.step(optimizer)
        scaler.update()
        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_acc.update(acc.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and dist.get_rank() == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {}, Loss: {:.6f}, Acc:{:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), [x['lr'] for x in optimizer.param_groups], ave_loss.average(),
                      ave_acc.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1
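`reduce_tensor` is another external helper in Examples 5, 8 and 10. Since the logging code in Example #5 divides the reduced average by `world_size` itself, a plain sum-reduction across ranks is consistent with the usage; a sketch under that assumption:

import torch
import torch.distributed

def reduce_tensor(tensor):
    # Hypothetical sketch: sum a loss tensor across all distributed workers.
    # The callers above divide the running average by world_size themselves.
    reduced = tensor.detach().clone()
    torch.distributed.all_reduce(reduced, op=torch.distributed.ReduceOp.SUM)
    return reduced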
Example #9
def train_epoch(model, optimizer, writer, args, hp, start_epoch=0):
    dataset_train = datasets.get_dataset(hp.train_script, hp, use_spec_aug=hp.use_spec_aug)
    train_sampler = DistributedSampler(dataset_train) if args.n_gpus > 1 else None
    dataloader = DataLoader(dataset_train, batch_size=hp.batch_size, shuffle=hp.shuffle, sampler=train_sampler,
                            num_workers=1, collate_fn=datasets.collate_fn, drop_last=True)
    step = len(dataloader) * start_epoch

    for epoch in range(start_epoch, hp.max_epoch):
        start_time = time.time()
        step = train_loop(model, optimizer, writer, step, args, hp)
        if (epoch + 1) % hp.save_per_epoch == 0 or (epoch+1) % hp.reset_optimizer_epoch > 10:
            torch.save(model.state_dict(), hp.save_dir + "/network.epoch{}".format(epoch + 1))
            torch.save(optimizer.state_dict(), hp.save_dir + "/network.optimizer.epoch{}".format(epoch + 1))
        if hp.encoder_type != 'Conformer':
            adjust_learning_rate(optimizer, epoch + 1)
        if (epoch + 1) % hp.reset_optimizer_epoch == 0:
            if hp.encoder_type != 'Conformer':
                optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
        print("EPOCH {} end".format(epoch + 1))
        print(f'elapsed time = {(time.time()-start_time)//60}m')
Example #10
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
         trainloader, optimizer, lr_scheduler, model, writer_dict, device):
    
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        losses, _ = model(images, labels, train_step=(lr_scheduler._step_count-1))
        loss = losses.mean()

        reduced_loss = reduce_tensor(loss)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters, 
                      batch_time.average(), lr, print_loss)
            logging.info(msg)
            
            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
            batch_time = AverageMeter()
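Examples 5, 8, 10 and 14 all call `adjust_learning_rate(optimizer, base_lr, num_iters, i_iter + cur_iters)` and log the returned rate. A sketch assuming a per-iteration polynomial (power 0.9) schedule of the kind common in semantic-segmentation training loops; the actual helper may differ:

def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iters, power=0.9):
    # Hypothetical sketch: polynomial decay of the learning rate per iteration.
    lr = base_lr * ((1 - float(cur_iters) / max_iters) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr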
Example #11
def main_worker(ngpus_per_node, args):
    global best_acc1

    cprint('=> modeling the network ...', 'green')
    model = magface.builder(args)
    model = torch.nn.DataParallel(model).cuda()
    # for name, param in model.named_parameters():
    #     cprint(' : layer name and parameter size - {} - {}'.format(name, param.size()), 'green')

    cprint('=> building the optimizer ...', 'green')
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    pprint.pprint(optimizer)

    cprint('=> building the dataloader ...', 'green')
    train_loader = dataloader.train_loader(args)

    cprint('=> building the criterion ...', 'green')
    criterion = magface.MagLoss(args.l_a, args.u_a, args.l_margin,
                                args.u_margin)

    global iters
    iters = 0

    cprint('=> starting training engine ...', 'green')
    for epoch in range(args.start_epoch, args.epochs):

        global current_lr
        current_lr = utils.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        do_train(train_loader, model, criterion, optimizer, epoch, args)

        # save pth
        if epoch % args.pth_save_epoch == 0:
            state_dict = model.state_dict()

            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': state_dict,
                    'optimizer': optimizer.state_dict(),
                },
                False,
                filename=os.path.join(
                    args.pth_save_fold,
                    '{}.pth'.format(str(epoch + args.start_epoch).zfill(5))))
            cprint(' : save pth for epoch {}'.format(epoch + 1))
Example #12
def train(epoch):
    model.train()
    utils.adjust_learning_rate(optimizer, epoch, args.step_size, args.lr,
                               args.gamma)
    print('epoch =', epoch, 'lr = ', optimizer.param_groups[0]['lr'])
    for iteration, (lr_tensor, hr_tensor) in enumerate(training_data_loader,
                                                       1):

        if args.cuda:
            lr_tensor = lr_tensor.to(device)  # ranges from [0, 1]
            hr_tensor = hr_tensor.to(device)  # ranges from [0, 1]

        optimizer.zero_grad()
        sr_tensor = model(lr_tensor)
        loss_l1 = l1_criterion(sr_tensor, hr_tensor)
        loss_sr = loss_l1

        loss_sr.backward()
        optimizer.step()
        # if iteration % 1000 == 0:
        print("===> Epoch[{}]({}/{}): Loss_l1: {:.5f}".format(
            epoch, iteration, len(training_data_loader), loss_l1.item()))
Example #13
    def client_update(self, optimizer, optimizer_args, local_epoch, n_round):
        self.model.train()
        self.model.to(self.device)

        optimizer = optimizer(self.model.parameters(), **optimizer_args)
        if n_round in self.args.lr_stage:  # decay when this round hits a scheduled stage
            adjust_learning_rate(optimizer, self.args.lr, 0.1, self.index_step)
            self.index_step += 1

        # to do
        criterion = MultiBoxLoss(self.cfg['num_classes'], 0.5, True, 0, True,
                                 3, 0.5, False, self.args.cuda)

        print("Client training ...")
        for epoch in range(local_epoch):
            for images, targets in tqdm(self.dataloader):
                # load train data
                if self.args.cuda:
                    images = Variable(images.cuda())
                    targets = [
                        Variable(ann.cuda(), volatile=True) for ann in targets
                    ]
                else:
                    images = Variable(images)
                    targets = [Variable(ann, volatile=True) for ann in targets]

                # forward
                out = self.model(images)
                # backprop
                optimizer.zero_grad()
                loss_l, loss_c = criterion(out, targets)
                loss = loss_l + loss_c
                loss.backward()
                optimizer.step()

        self.model.to("cpu")
Example #14
def train(config, epoch, num_epoch, epoch_iters, base_lr, 
        num_iters, trainloader, optimizer, model, writer_dict):

    if config.DATASET.DATASET == "pneumothorax":
        trainloader.dataset.update_train_ds(config.DATASET.WEIGHT_POSITIVE)
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        # import pdb; pdb.set_trace()
        labels = labels.long().cuda()
        losses, _ = model(images, labels)
        loss = losses.mean()

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss.item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters, 
                      batch_time.average(), lr, ave_loss.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1
Example #15
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        print("Epoch " + str(epoch) + ':')
        tbar = tqdm(self.train_loader)

        for i, (input, heatmap, centermap, img_path, segmented,
                box) in enumerate(tbar):
            learning_rate = adjust_learning_rate(self.optimizer,
                                                 self.iters,
                                                 self.lr,
                                                 policy='step',
                                                 gamma=self.gamma,
                                                 step_size=self.step_size)

            input_var = input.cuda()
            heatmap_var = heatmap.cuda()
            centermap_var = centermap.cuda()

            self.optimizer.zero_grad()

            heat = torch.zeros(self.numClasses + 1, 46, 46).cuda()
            cell = torch.zeros(15, 46, 46).cuda()
            hide = torch.zeros(15, 46, 46).cuda()

            losses = {}
            loss = 0

            start_model = time.time()
            for j in range(self.frame_memory):
                heat, cell, hide = self.model(input_var, centermap_var, j,
                                              heat, hide, cell)

                losses[j] = self.criterion(heat, heatmap_var[0:, j])
                loss += losses[j]

            train_loss += loss.item()

            loss.backward()
            self.optimizer.step()

            tbar.set_description('Train loss: %.6f' %
                                 (train_loss / ((i + 1) * self.batch_size)))

            self.iters += 1
Example #16
def main():
    global args
    print("config:{0}".format(args))

    checkpoint_dir = args.checkpoint_folder

    global_step = 0
    min_val_loss = 999999999

    title = 'train|val loss '
    init = np.NaN
    win_feats5 = viz.line(
        X=np.column_stack((np.array([init]), np.array([init]))),
        Y=np.column_stack((np.array([init]), np.array([init]))),
        opts={
            'title': title,
            'xlabel': 'Iter',
            'ylabel': 'Loss',
            'legend': ['train_feats5', 'val_feats5']
        },
    )

    win_fusion = viz.line(
        X=np.column_stack((np.array([init]), np.array([init]))),
        Y=np.column_stack((np.array([init]), np.array([init]))),
        opts={
            'title': title,
            'xlabel': 'Iter',
            'ylabel': 'Loss',
            'legend': ['train_fusion', 'val_fusion']
        },
    )

    train_loader, val_loader = prep_SBD_dataset.get_dataloader(args)
    model = CASENet_resnet101(pretrained=False, num_classes=args.cls_num)

    if args.multigpu:
        model = torch.nn.DataParallel(model.cuda())
    else:
        model = model.cuda()

    policies = get_model_policy(model)  # Set the lr_mult=10 of new layer
    optimizer = torch.optim.SGD(policies,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.pretrained_model:
        utils.load_pretrained_model(model, args.pretrained_model)

    if args.resume_model:
        checkpoint = torch.load(args.resume_model)
        args.start_epoch = checkpoint['epoch'] + 1
        min_val_loss = checkpoint['min_loss']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    for epoch in range(args.start_epoch, args.epochs):
        curr_lr = utils.adjust_learning_rate(args.lr, args, optimizer,
                                             global_step, args.lr_steps)

        global_step = model_play.train(args, train_loader, model, optimizer, epoch, curr_lr,\
                                 win_feats5, win_fusion, viz, global_step)

        curr_loss = model_play.validate(args, val_loader, model, epoch,
                                        win_feats5, win_fusion, viz,
                                        global_step)

        # Always store current model to avoid process crashed by accident.
        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'min_loss': min_val_loss,
            },
            epoch,
            folder=checkpoint_dir,
            filename="curr_checkpoint.pth.tar")

        if curr_loss < min_val_loss:
            min_val_loss = curr_loss
            utils.save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'min_loss': min_val_loss,
                },
                epoch,
                folder=checkpoint_dir)
            print("Min loss is {0}, in {1} epoch.".format(min_val_loss, epoch))
Example #17
def main():

    args = parse_args()
    args.pretrain = False

    root_path = 'exps/exp_{}'.format(args.exp)

    if not os.path.exists(root_path):
        os.mkdir(root_path)
        os.mkdir(os.path.join(root_path, "log"))
        os.mkdir(os.path.join(root_path, "model"))

    base_lr = args.lr  # base learning rate

    train_dataset, val_dataset = build_dataset(args.dataset, args.data_root,
                                               args.train_list)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.7)

    model = torch.nn.DataParallel(model)

    model.train()

    if args.resume is None:
        assert os.path.exists(args.load_path)
        state_dict = model.state_dict()
        print("Loading weights...")
        pretrain_state_dict = torch.load(args.load_path,
                                         map_location="cpu")['state_dict']

        for k in list(pretrain_state_dict.keys()):
            if k not in state_dict:
                del pretrain_state_dict[k]
        model.load_state_dict(pretrain_state_dict)
        print("Loaded weights")
    else:
        print("Resuming from {}".format(args.resume))
        checkpoint = torch.load(args.resume, map_location="cpu")

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(root_path)
    saver = Saver(root_path)

    for epoch in range(args.start_epoch, args.epochs):
        train(model, train_loader, optimizer, logger, args, epoch)
        validate(model, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)
Example #18
def main():

    args = parse_args()
    if args.turnon < 0:
        args.pretrain = True
    else:
        args.pretrain = False
    print("Using GPU: {}".format(args.local_rank))
    root_path = 'exps/exp_{}'.format(args.exp)
    if args.local_rank == 0 and not os.path.exists(root_path):
        os.mkdir(root_path)
        os.mkdir(os.path.join(root_path, "log"))
        os.mkdir(os.path.join(root_path, "model"))

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    train_dataset, val_dataset = build_dataset(args.dataset,
                                               args.data_root,
                                               args.train_list,
                                               sampling=args.sampling)
    args.world_size = len(args.gpu.split(","))
    if args.world_size > 1:
        os.environ['MASTER_PORT'] = args.port
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group('nccl')
        device = torch.device('cuda:{}'.format(args.local_rank))
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=len(args.gpu.split(",")),
            rank=args.local_rank)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes, input_size=64,
                 pretrain=True).cuda(args.local_rank)
    model_ema = VNet(args.n_channels,
                     args.n_classes,
                     input_size=64,
                     pretrain=True).cuda(args.local_rank)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    if args.world_size > 1:
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        model_ema = DDP(model_ema,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank,
                        find_unused_parameters=True)

    model.train()
    model_ema.load_state_dict(model.state_dict())
    print("Loaded weights")

    logger = Logger(root_path)
    saver = Saver(root_path, save_freq=args.save_freq)
    if args.sampling == 'default':
        contrast = RGBMoCo(128, K=4096,
                           T=args.temperature).cuda(args.local_rank)
    elif args.sampling == 'layerwise':
        contrast = RGBMoCoNew(128, K=4096,
                              T=args.temperature).cuda(args.local_rank)
    else:
        raise ValueError("unsupported sampling method")
    criterion = torch.nn.CrossEntropyLoss()

    flag = False
    for epoch in range(args.start_epoch, args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        train(model, model_ema, train_loader, optimizer, logger, saver, args,
              epoch, contrast, criterion)
        validate(model_ema, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)
Example #19
def main_worker(gpu, args):
    # check the feasible of the lambda g
    s = 64
    k = (args.u_margin-args.l_margin)/(args.u_a-args.l_a)
    min_lambda = s*k*args.u_a**2*args.l_a**2/(args.u_a**2-args.l_a**2)
    color_lambda = 'red' if args.lambda_g < min_lambda else 'green'
    ngpus_per_node = torch.cuda.device_count()
    
    args.gpu = gpu
    args.rank = args.nr * args.gpus + args.gpu
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank)
    torch.manual_seed(0)
    # logs
    if args.rank == 0:
        cprint('min lambda g is {}, current lambda is {}'.format(
            min_lambda, args.lambda_g), color_lambda)
        cprint('=> torch version : {}'.format(torch.__version__), 'green')
        cprint('=> ngpus : {}'.format(ngpus_per_node), 'green')
    
    # init torchshard
    ts.distributed.init_process_group(group_size=args.world_size)

    global best_acc1
    if args.rank == 0:
        cprint('=> modeling the network ...', 'green')
    model = magface_dist.builder(args)
    # for name, param in model.named_parameters():
    #     cprint(' : layer name and parameter size - {} - {}'.format(name, param.size()), 'green')

    if args.rank == 0:
        cprint('=> building the optimizer ...', 'green')
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    
    if args.rank == 0:
        pprint.pprint(optimizer)
        cprint('=> building the dataloader ...', 'green')
        cprint('=> building the criterion ...', 'green')

    grad_scaler = GradScaler(enabled=args.amp_mode)
    train_loader = dataloader_dist.train_loader(args)
    from models.parallel_magloss import ParallelMagLoss
    criterion = ParallelMagLoss(
        args.l_a, args.u_a, args.l_margin, args.u_margin)

    global iters
    iters = 0
    if args.rank == 0:
        cprint('=> starting training engine ...', 'green')
    for epoch in range(args.start_epoch, args.epochs):
        global current_lr
        current_lr = utils.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        do_train(train_loader, model, criterion, optimizer, grad_scaler, epoch, args)

        # ts.collect_state_dict() needs to see all the process groups
        state_dict = model.state_dict()
        state_dict = ts.collect_state_dict(model, state_dict)
    
        # save pth
        if epoch % args.pth_save_epoch == 0:
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
             }, False,
             filename=os.path.join(
                args.pth_save_fold, '{}.pth'.format(str(epoch+1).zfill(5)))
             )            
            cprint(' : save pth for epoch {}'.format(epoch + 1))
Example #20
def main():

    #########  configs ###########
    best_metric = 0
    ######  load datasets ########
    train_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    val_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    train_data = dates.Dataset(cfg.TRAIN_DATA_PATH,
                               cfg.TRAIN_LABEL_PATH,
                               cfg.TRAIN_TXT_PATH,
                               'train',
                               transform=True,
                               transform_med=train_transform_det)
    train_loader = Data.DataLoader(train_data,
                                   batch_size=cfg.BATCH_SIZE,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)
    val_data = dates.Dataset(cfg.VAL_DATA_PATH,
                             cfg.VAL_LABEL_PATH,
                             cfg.VAL_TXT_PATH,
                             'val',
                             transform=True,
                             transform_med=val_transform_det)
    val_loader = Data.DataLoader(val_data,
                                 batch_size=cfg.BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
    ######  build  models ########
    base_seg_model = 'deeplab'
    if base_seg_model == 'deeplab':
        import model.siameseNet.deeplab_v2 as models
        pretrain_deeplab_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                             'deeplab_v2_voc12.pth')
        model = models.SiameseNet(norm_flag='l2')
        if resume:
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            deeplab_pretrain_model = torch.load(pretrain_deeplab_path)
            model.init_parameters_from_deeplab(deeplab_pretrain_model)
    else:
        import model.siameseNet.fcn32s_tiny as models
        pretrain_vgg_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                         'vgg16_from_caffe.pth')
        model = models.SiameseNet(distance_flag='softmax')
        if resume:
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            vgg_pretrain_model = util.load_pretrain_model(pretrain_vgg_path)
            model.init_parameters(vgg_pretrain_model)

    model = model.cuda()
    MaskLoss = ls.ConstractiveMaskLoss()
    ab_test_dir = os.path.join(cfg.SAVE_PRED_PATH, 'contrastive_loss')
    check_dir(ab_test_dir)
    save_change_map_dir = os.path.join(ab_test_dir, 'changemaps/')
    save_valid_dir = os.path.join(ab_test_dir, 'valid_imgs')
    save_roc_dir = os.path.join(ab_test_dir, 'roc')
    check_dir(save_change_map_dir), check_dir(save_valid_dir), check_dir(
        save_roc_dir)
    #########
    ######### optimizer ##########
    ######## how to set different learning rate for differernt layers #########
    optimizer = torch.optim.SGD(
        [{
            'params': set_base_learning_rate_for_multi_layer(model),
            'lr': cfg.INIT_LEARNING_RATE
        }, {
            'params': set_2x_learning_rate_for_multi_layer(model),
            'lr': 2 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }, {
            'params': set_10x_learning_rate_for_multi_layer(model),
            'lr': 10 * cfg.INIT_LEARNING_RATE
        }, {
            'params': set_20x_learning_rate_for_multi_layer(model),
            'lr': 20 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }],
        lr=cfg.INIT_LEARNING_RATE,
        momentum=cfg.MOMENTUM,
        weight_decay=cfg.DECAY)
    ######## iter img_label pairs ###########
    loss_total = 0
    for epoch in range(100):
        for batch_idx, batch in enumerate(train_loader):
            step = epoch * len(train_loader) + batch_idx
            util.adjust_learning_rate(cfg.INIT_LEARNING_RATE, optimizer, step)
            model.train()
            img1_idx, img2_idx, label_idx, filename, height, width = batch
            img1, img2, label = Variable(img1_idx.cuda()), Variable(
                img2_idx.cuda()), Variable(label_idx.cuda())
            out_conv5, out_fc, out_embedding = model(img1, img2)
            out_conv5_t0, out_conv5_t1 = out_conv5
            out_fc_t0, out_fc_t1 = out_fc
            out_embedding_t0, out_embedding_t1 = out_embedding
            label_rz_conv5 = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_conv5_t0.data.cpu().numpy().shape[2:]).cuda())
            label_rz_fc = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_fc_t0.data.cpu().numpy().shape[2:]).cuda())
            label_rz_embedding = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_embedding_t0.data.cpu().numpy().shape[2:]).cuda())
            contractive_loss_conv5 = MaskLoss(out_conv5_t0, out_conv5_t1,
                                              label_rz_conv5)
            contractive_loss_fc = MaskLoss(out_fc_t0, out_fc_t1, label_rz_fc)
            contractive_loss_embedding = MaskLoss(out_embedding_t0,
                                                  out_embedding_t1,
                                                  label_rz_embedding)
            loss = contractive_loss_conv5 + contractive_loss_fc + contractive_loss_embedding
            loss_total += loss.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % 20 == 0:
                print(
                    "Epoch [%d] Iter [%d] Loss: %.4f Mask_Loss_conv5: %.4f Mask_Loss_fc: %.4f "
                    "Mask_Loss_embedding: %.4f" %
                    (epoch, batch_idx, loss.item(),
                     contractive_loss_conv5.item(),
                     contractive_loss_fc.item(),
                     contractive_loss_embedding.item()))
            if (batch_idx) % 1000 == 0:
                model.eval()
                current_metric = validate(model, val_loader, epoch,
                                          save_change_map_dir, save_roc_dir)
                if current_metric > best_metric:
                    torch.save({'state_dict': model.state_dict()},
                               os.path.join(ab_test_dir,
                                            'model' + str(epoch) + '.pth'))
                    shutil.copy(
                        os.path.join(ab_test_dir,
                                     'model' + str(epoch) + '.pth'),
                        os.path.join(ab_test_dir, 'model_best.pth'))
                    best_metric = current_metric
        current_metric = validate(model, val_loader, epoch,
                                  save_change_map_dir, save_roc_dir)
        if current_metric > best_metric:
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(ab_test_dir,
                                    'model' + str(epoch) + '.pth'))
            shutil.copy(
                os.path.join(ab_test_dir, 'model' + str(epoch) + '.pth'),
                os.path.join(ab_test_dir, 'model_best.pth'))
            best_metric = current_metric
        if epoch % 5 == 0:
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(ab_test_dir,
                                    'model' + str(epoch) + '.pth'))
Example #21
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    miml = MIML(freeze=False, fine_tune=True)

    pretrained_net_dict = torch.load(
        checkpoint_miml, map_location=lambda storage, loc: storage)['model']
    new_state_dict = OrderedDict()
    for k, v in pretrained_net_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
        # load params
    miml.load_state_dict(new_state_dict)
    miml_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                    miml.parameters()),
                                      lr=miml_lr,
                                      weight_decay=1e-4)

    decoder = Decoder(attrs_dim=attrs_dim,
                      embed_dim=emb_dim,
                      decoder_dim=decoder_dim,
                      attrs_size=attrs_size,
                      vocab_size=len(word_map),
                      device=device,
                      dropout=dropout)

    decoder_optimizer = torch.optim.Adam(params=filter(
        lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr,
                                         weight_decay=1e-4)

    if checkpoint:
        checkpoint = torch.load(checkpoint,
                                map_location=lambda storage, loc: storage)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        miml.load_state_dict(checkpoint['miml'])
        decoder.load_state_dict(checkpoint['decoder'])
        decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer'])

    miml = miml.to(device)
    decoder = decoder.to(device)
    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=pin_memory)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=pin_memory)
    writer = SummaryWriter(log_dir='./log_miml')
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 4 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 4 == 0:
            adjust_learning_rate(decoder_optimizer, 0.9)

        # One epoch's training
        train(train_loader=train_loader,
              miml=miml,
              decoder=decoder,
              criterion=criterion,
              miml_optimizer=miml_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              writer=writer)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                miml=miml,
                                decoder=decoder,
                                criterion=criterion,
                                epoch=epoch,
                                writer=writer)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint_miml(prefix, data_name, epoch,
                             epochs_since_improvement, miml, decoder,
                             miml_optimizer, decoder_optimizer, recent_bleu4,
                             is_best)
Example #22
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   device=device,
                                   dropout=dropout)
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=encoder_lr) if fine_tune_encoder else None

    if checkpoint:
        checkpoint = torch.load(
            checkpoint, map_location=lambda storage, loc: storage)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder.load_state_dict(checkpoint['decoder'])
        decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer'])
        encoder.load_state_dict(checkpoint['encoder'])
        if fine_tune_encoder:
            encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer'])

    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN',
                       transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=False)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL',
                       transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=False)
    writer = SummaryWriter(log_dir='./log_basemodel')
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              writer=writer)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                epoch=epoch,
                                writer=writer)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint_basemodel(prefix, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                                  decoder_optimizer, recent_bleu4, is_best)
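Examples 21 and 22 call `adjust_learning_rate(decoder_optimizer, 0.9)` / `(..., 0.8)` with a shrink factor rather than an absolute rate. A sketch assuming multiplicative decay of each parameter group's current rate:

def adjust_learning_rate(optimizer, shrink_factor):
    # Hypothetical sketch: multiply every parameter group's learning rate
    # by shrink_factor (expected to be < 1).
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor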
Example #23
def adj_lr(optimizer, epoch):
    if epoch in [120, 150]:
        adjust_learning_rate(optimizer)
    return optimizer.param_groups[0]['lr']
Example #24
def train(ckpt, num_epochs, batch_size, device):
    start = time.time()

    num_workers = 0
    lr = 8e-4
    momentum = 0
    weight_decay = 0

    directory = 'data/'
    start_epoch = 0
    start_loss = 0
    print_freq = 10
    checkpoint_interval = 1
    evaluation_interval = 1

    logger = Logger('./logs')

    model, train_dataset, val_dataset, criterion_grid, optimizer = init_model_and_dataset(
        directory, device, lr, weight_decay)

    # load the pretrained network
    if ckpt is not None:
        checkpoint = torch.load(ckpt)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_loss = checkpoint['loss']

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             pin_memory=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()
        adjust_learning_rate(optimizer, epoch, lr)

        # train for one epoch
        batch_time = AverageMeter()
        total_time = AverageMeter()
        train_loss = AverageMeter()

        train_fingers_recall = AverageMeter()
        train_fingers_precision = AverageMeter()

        train_frets_recall = AverageMeter()
        train_frets_precision = AverageMeter()

        train_strings_recall = AverageMeter()
        train_strings_precision = AverageMeter()

        train_loss.update(start_loss)

        # switch to train mode
        model.train()

        for data_idx, data in enumerate(train_loader):
            batch_start = time.time()
            input = data['image'].float().to(device)
            target = data['fingers'].float().to(device)
            frets = data['frets'].float().to(device)
            strings = data['strings'].float().to(device)
            target_coord = data['finger_coord']
            frets_coord = data['fret_coord']
            strings_coord = data['string_coord']

            # compute output
            output = model(input)
            output1 = output[0].split(input.shape[0], dim=0)
            output2 = output[1].split(input.shape[0], dim=0)
            output3 = output[2].split(input.shape[0], dim=0)

            loss1 = sum(criterion_grid(o, target) for o in output1)
            loss2 = sum(criterion_grid(o, frets) for o in output2)
            loss3 = sum(criterion_grid(o, strings) for o in output3)

            loss = loss1 / 2 + loss2 + loss3

            # measure accuracy and record loss
            accuracy(output=output1[-1].data,
                     target=target,
                     global_precision=train_fingers_precision,
                     global_recall=train_fingers_recall,
                     fingers=target_coord,
                     min_dist=10)

            accuracy(output=output2[-1].data,
                     target=frets,
                     global_precision=train_frets_precision,
                     global_recall=train_frets_recall,
                     fingers=frets_coord.unsqueeze(0),
                     min_dist=5)

            accuracy(output=output3[-1].data,
                     target=strings,
                     global_precision=train_strings_precision,
                     global_recall=train_strings_recall,
                     fingers=strings_coord.unsqueeze(0),
                     min_dist=5)

            train_loss.update(loss.item())

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - batch_start)
            total_time.update((time.time() - start) / 60)

            if data_idx % print_freq == 0 and data_idx != 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Loss.avg: {loss.avg:.4f}\t'
                    'Batch time: {batch_time:.4f} s\t'
                    'Total time: {total_time:.4f} min\n'
                    'FINGERS: \t'
                    'Recall(%): {top1:.3f}\t'
                    'Precision(%): {top2:.3f}\n'
                    'FRETS:   \t'
                    'Recall(%): {top6:.3f}\t'
                    'Precision(%): {top7:.3f}\n'
                    'STRINGS: \t'
                    'Recall(%): {top11:.3f}\t'
                    'Precision(%): {top12:.3f}\n'
                    '---------------------------------------------------------------------------------------------'
                    .format(epoch,
                            data_idx,
                            len(train_loader),
                            loss=train_loss,
                            batch_time=batch_time.val,
                            total_time=total_time.val,
                            top1=train_fingers_recall.avg * 100,
                            top2=train_fingers_precision.avg * 100,
                            top6=train_frets_recall.avg * 100,
                            top7=train_frets_precision.avg * 100,
                            top11=train_strings_recall.avg * 100,
                            top12=train_strings_precision.avg * 100))

        if epoch % evaluation_interval == 0:
            # evaluate on validation set
            print(
                '---------------------------------------------------------------------------------------------\n'
                'Train set:  ')

            t_recall1, t_recall2, t_recall3, t_precision1, t_precision2, t_precision3 = test(
                train_loader, model, device)
            print('Validation set:  ')
            e_recall1, e_recall2, e_recall3, e_precision1, e_precision2, e_precision3 = test(
                val_loader, model, device, show=False)

            print(
                '---------------------------------------------------------------------------------------------\n'
                '---------------------------------------------------------------------------------------------'
            )

            # 1. Log scalar values (scalar summary)
            info = {
                'Train Loss': train_loss.avg,
                '(Fingers) Train Recall': t_recall1,
                '(Fingers) Train Precision': t_precision1,
                '(Fingers) Validation Recall': e_recall1,
                '(Fingers) Validation Precision': e_precision1,
                '(Frets) Train Recall': t_recall2,
                '(Frets) Train Precision': t_precision2,
                '(Frets) Validation Recall': e_recall2,
                '(Frets) Validation Precision': e_precision2,
                '(Strings) Train Recall': t_recall3,
                '(Strings) Train Precision': t_precision3,
                '(Strings) Validation Recall': e_recall3,
                '(Strings) Validation Precision': e_precision3
            }

            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            # 2. Log values and gradients of the parameters (histogram summary)
            for tag, value in model.named_parameters():
                tag = tag.replace('.', '/')
                try:
                    logger.histo_summary(tag, value.data.cpu().numpy(), epoch)
                except ValueError:
                    print('Skipping histogram summary for {} (invalid values)'.format(tag))
                logger.histo_summary(tag + '/grad',
                                     value.grad.data.cpu().numpy(), epoch)

            # 3. Log training images (image summary)
            info = {'images': input.view(-1, 300, 300).cpu().numpy()}

            for tag, images in info.items():
                logger.image_summary(tag, images, epoch)

        # remember best acc and save checkpoint
        if epoch % checkpoint_interval == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss.avg
                }, "checkpoints/hg_ckpt_{0}.pth".format(epoch))

    print('Epoch time: {time:.2f} min'.format(time=(time.time() - epoch_start) / 60))
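
The call adjust_learning_rate(optimizer, epoch, lr) above recomputes the rate from the base value each epoch. A minimal step-decay sketch; the 30-epoch interval and the 10x factor are assumptions, not taken from the snippet:

def adjust_learning_rate(optimizer, epoch, base_lr, decay_every=30):
    # Recompute the learning rate from the base value: divide by 10 for
    # every `decay_every` completed epochs, then write it into the optimizer.
    new_lr = base_lr * (0.1 ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr
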
Exemplo n.º 25
0
def main():
    # Parse the options
    opts = Opts().parse()
    opts.device = torch.device(f'cuda:{opts.gpu[0]}')
    print(opts.expID, opts.task)
    # Record the start time
    time_start = time.time()
    # TODO: select the dataset by the options
    # Set up dataset
    train_loader_unit = PENN_CROP(opts, 'train')
    train_loader = tud.DataLoader(train_loader_unit,
                                  batch_size=opts.trainBatch,
                                  shuffle=False,
                                  num_workers=int(opts.num_workers))
    val_loader = tud.DataLoader(PENN_CROP(opts, 'val'),
                                batch_size=1,
                                shuffle=False,
                                num_workers=int(opts.num_workers))

    # Read the number of joints (output dimension) from the dataset
    opts.nJoints = train_loader_unit.part.shape[1]
    # Create the Model, Optimizer and Criterion
    if opts.loadModel == 'none':
        model = Hourglass2DPrediction(opts).cuda(device=opts.device)
    else:
        model = torch.load(opts.loadModel).cuda(device=opts.device)
    # Set the Criterion and Optimizer
    criterion = torch.nn.MSELoss(reduction='none').cuda(device=opts.device)
    # opts.nOutput = len(model.outnode.children)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    opts.LR,
                                    alpha=opts.alpha,
                                    eps=opts.epsilon,
                                    weight_decay=opts.weightDecay,
                                    momentum=opts.momentum)
    # If TEST, just validate
    # TODO: save the validate results to mat or hdf5
    if opts.test:
        loss_test, pck_test = val(0, opts, val_loader, model, criterion)
        print(f"test: | loss_test: {loss_test}| PCK_val: {pck_test}\n")
        ## TODO: save the predictions for the test
        #sio.savemat(os.path.join(opts.saveDir, 'preds.mat'), mdict = {'preds':preds})
        return
    # NOT TEST, Train and Validate
    for epoch in range(1, opts.nEpochs + 1):
        ## Train the model
        loss_train, pck_train = train(epoch, opts, train_loader, model,
                                      criterion, optimizer)
        ## Show results and elapsed time
        time_elapsed = time.time() - time_start
        print(
            f"epoch: {epoch} | loss_train: {loss_train} | PCK_train: {pck_train} | {time_elapsed//60:.0f}min {time_elapsed%60:.0f}s\n"
        )
        ## Intervals to show eval results
        if epoch % opts.valIntervals == 0:
            # TODO: Test the validation part
            ### Validation
            loss_val, pck_val = val(epoch, opts, val_loader, model, criterion)
            print(
                f"epoch: {epoch} | loss_val: {loss_val}| PCK_val: {pck_val}\n")
            ### Save the model
            torch.save(model, os.path.join(opts.save_dir,
                                           f"model_{epoch}.pth"))
            ### TODO: save the preds for the validation
            #sio.savemat(os.path.join(opts.saveDir, f"preds_{epoch}.mat"), mdict={'preds':preds})
        # Use the optimizer to adjust learning rate
        if epoch % opts.dropLR == 0:
            lr = adjust_learning_rate(optimizer, epoch, opts.dropLR, opts.LR)
            print(f"Drop LR to {lr}\n")
Exemplo n.º 26
0
def train_val(model, args):

    train_dir = args.train_dir
    val_dir = args.val_dir

    config = Config(args.config)
    cudnn.benchmark = True

    # train
    train_loader = torch.utils.data.DataLoader(lsp_lspet_data.LSP_Data(
        'lspet', train_dir, 8,
        Mytransforms.Compose([
            Mytransforms.RandomResized(),
            Mytransforms.RandomRotate(40),
            Mytransforms.RandomCrop(368),
            Mytransforms.RandomHorizontalFlip(),
        ])),
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True)
    # val
    if args.val_dir is not None and config.test_interval != 0:
        # val
        val_loader = torch.utils.data.DataLoader(lsp_lspet_data.LSP_Data(
            'lsp', val_dir, 8,
            Mytransforms.Compose([
                Mytransforms.TestResized(368),
            ])),
                                                 batch_size=config.batch_size,
                                                 shuffle=True,
                                                 num_workers=config.workers,
                                                 pin_memory=True)

    if args.gpu[0] < 0:
        criterion = nn.MSELoss()
    else:
        criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, True)
    # params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params,
                                config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()
    iters = config.start_iters
    best_model = config.best_model

    heat_weight = 46 * 46 * 15 / 1.0

    losstracker1 = []

    losstracker2 = []
    losstracker3 = []
    losstracker4 = []
    losstracker5 = []
    losstracker6 = []
    while iters < config.max_iter:

        for i, (input, heatmap, centermap) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer,
                iters,
                config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            if args.gpu[0] >= 0:
                heatmap = heatmap.cuda(non_blocking=True)
                centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            #print(input.size(0).item())
            losses.update(loss.item(), input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5,
                                     loss6]):
                losses_list[cnt].update(l.item(), input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            if iters % config.display == 0:
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters,
                        config.display,
                        learning_rate,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))

                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))

                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

            save_checkpoint({
                'iter': iters,
                'state_dict': model.state_dict(),
            }, 0, args.model_name)

            # val
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:

                model.eval()
                for j, (input, heatmap, centermap) in enumerate(val_loader):
                    if args.gpu[0] >= 0:
                        heatmap = heatmap.cuda(non_blocking=True)
                        centermap = centermap.cuda(non_blocking=True)

                    input_var = torch.autograd.Variable(input)
                    heatmap_var = torch.autograd.Variable(heatmap)
                    centermap_var = torch.autograd.Variable(centermap)

                    heat1, heat2, heat3, heat4, heat5, heat6 = model(
                        input_var, centermap_var)

                    loss1 = criterion(heat1, heatmap_var) * heat_weight
                    loss2 = criterion(heat2, heatmap_var) * heat_weight
                    loss3 = criterion(heat3, heatmap_var) * heat_weight
                    loss4 = criterion(heat4, heatmap_var) * heat_weight
                    loss5 = criterion(heat5, heatmap_var) * heat_weight
                    loss6 = criterion(heat6, heatmap_var) * heat_weight

                    loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
                    losses.update(loss.item(), input.size(0))
                    for cnt, l in enumerate(
                        [loss1, loss2, loss3, loss4, loss5, loss6]):
                        losses_list[cnt].update(l.item(), input.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()
                    is_best = losses.avg < best_model
                    best_model = min(best_model, losses.avg)
                    save_checkpoint(
                        {
                            'iter': iters,
                            'state_dict': model.state_dict(),
                        }, is_best, args.model_name)

                    if j % config.display == 0:
                        print(
                            'Test Iteration: {0}\t'
                            'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                            'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:3f})\n'
                            'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                            format(j,
                                   config.display,
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses))
                        for cnt in range(0, 6):
                            print(
                                'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                .format(cnt + 1, loss1=losses_list[cnt]))

                        print(
                            time.strftime(
                                '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                time.localtime()))
                        batch_time.reset()
                        losses.reset()
                        for cnt in range(6):
                            losses_list[cnt].reset()

                        losstracker1.append(loss1.item())
                        losstracker2.append(loss2.item())
                        losstracker3.append(loss3.item())
                        losstracker4.append(loss4.item())
                        losstracker5.append(loss5.item())
                        losstracker6.append(loss6.item())
                model.train()

    np.save('loss1', np.asarray(losstracker1))
    np.save('loss2', np.asarray(losstracker2))
    np.save('loss3', np.asarray(losstracker3))
    np.save('loss4', np.asarray(losstracker4))
    np.save('loss5', np.asarray(losstracker5))
    np.save('loss6', np.asarray(losstracker6))
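
The learning-rate helper in this example is iteration-based and policy-driven, and it scales each parameter group by a per-group multiplier (multiple, returned by get_parameters). A sketch supporting a 'step' and a 'fixed' policy; the exact set of policies and the keys assumed inside policy_parameter ('gamma', 'step_size') are assumptions:

def adjust_learning_rate(optimizer, iters, base_lr, policy='step',
                         policy_parameter=None, multiple=None):
    # Compute the global learning rate from the chosen policy, then apply it
    # to each parameter group scaled by its own multiplier (e.g. a larger
    # rate for bias terms).
    if policy == 'fixed' or policy_parameter is None:
        lr = base_lr
    elif policy == 'step':
        lr = base_lr * (policy_parameter['gamma'] **
                        (iters // policy_parameter['step_size']))
    else:
        raise ValueError('unsupported lr policy: %s' % policy)

    if multiple is None:
        multiple = [1.0] * len(optimizer.param_groups)
    for group, mult in zip(optimizer.param_groups, multiple):
        group['lr'] = lr * mult
    return lr
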
Exemplo n.º 27
0
def main(opt):

    # Set the random seed manually for reproducibility (CPU and, if available, GPU).
    torch.manual_seed(opt.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(opt.seed)

    train_loader = get_data_loader(opt,
                                   split='train',
                                   return_org_image=False)

    val_loader = get_data_loader(opt,
                                 split='val',
                                 return_org_image=False)

    output_dir = os.path.dirname(opt.output_file)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger.info('Building model...')

    model = LaneNet(cnn_type=opt.cnn_type, embed_dim=opt.embed_dim)
    model = DataParallelModel(model)

    criterion_disc = DiscriminativeLoss(delta_var=0.5,
                                        delta_dist=1.5,
                                        norm=2,
                                        usegpu=True)

    criterion_ce = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)

    if opt.start_from:
        logger.info('Restart training from %s', opt.start_from)
        checkpoint = torch.load(opt.start_from)
        model.load_state_dict(checkpoint['model'])

    if torch.cuda.is_available():
        criterion_disc.cuda()
        criterion_ce.cuda()
        model = model.cuda()

    logger.info("Start training...")
    best_loss = sys.maxsize
    best_epoch = 0

    for epoch in tqdm(range(opt.num_epochs), desc='Epoch: '):
        learning_rate = adjust_learning_rate(opt, optimizer, epoch)
        logger.info('===> Learning rate: %f: ', learning_rate)

        # train for one epoch
        train(
            opt,
            model,
            criterion_disc,
            criterion_ce,
            optimizer,
            train_loader)

        # validate at every val_step epoch
        if epoch % opt.val_step == 0:
            val_loss = test(
                opt,
                model,
                criterion_disc,
                criterion_ce,
                val_loader)
            logger.info('Val loss: %s\n', val_loss)

            loss = val_loss.avg
            if loss < best_loss:
                logger.info(
                    'Found new best loss: %.7f, previous loss: %.7f',
                    loss,
                    best_loss)
                best_loss = loss
                best_epoch = epoch

                logger.info('Saving new checkpoint to: %s', opt.output_file)
                torch.save({
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'best_loss': best_loss,
                    'best_epoch': best_epoch,
                    'opt': opt
                }, opt.output_file)

            else:
                logger.info(
                    'Current loss: %.7f, best loss is %.7f @ epoch %d',
                    loss,
                    best_loss,
                    best_epoch)

        if epoch - best_epoch > opt.max_patience:
            logger.info('Terminated by early stopping!')
            break
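
This variant reads everything from the options object and returns the rate so it can be logged. A sketch that decays opt.learning_rate by 10x every opt.lr_update epochs; opt.lr_update is a hypothetical option name introduced only for this illustration:

def adjust_learning_rate(opt, optimizer, epoch):
    # Decay opt.learning_rate by a factor of 10 every opt.lr_update epochs
    # (opt.lr_update is a hypothetical option used only in this sketch)
    # and return the applied value so the caller can log it.
    lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
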
Exemplo n.º 28
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='number of GPUs to use')

    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--lr-decay-every',
                        type=int,
                        default=100,
                        help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar',
                        type=float,
                        default=0.1,
                        help='learning rate decay factor applied every --lr-decay-every epochs')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--run_test',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='run test only')

    parser.add_argument(
        '--limit_training_batches',
        type=int,
        default=-1,
        help='how many batches to do per training, -1 means as many as possible'
    )

    parser.add_argument('--no_grad_clip',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='turn off gradient clipping')

    parser.add_argument('--get_flops',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='add hooks to compute flops')

    parser.add_argument(
        '--get_inference_time',
        default=False,
        type=str2bool,
        nargs='?',
        help='runs valid multiple times and reports the result')

    parser.add_argument('--mgpu',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use data paralization via multiple GPUs')

    parser.add_argument('--dataset',
                        default="MNIST",
                        type=str,
                        help='dataset for experiment, choice: MNIST, CIFAR10',
                        choices=["MNIST", "CIFAR10", "Imagenet"])

    parser.add_argument('--data',
                        metavar='DIR',
                        default='/imagenet',
                        help='path to imagenet dataset')

    parser.add_argument(
        '--model',
        default="lenet3",
        type=str,
        help='model selection, choices: lenet3, vgg, mobilenetv2, resnet18',
        choices=[
            "lenet3", "vgg", "mobilenetv2", "resnet18", "resnet152",
            "resnet50", "resnet50_noskip", "resnet20", "resnet34", "resnet101",
            "resnet101_noskip", "densenet201_imagenet", 'densenet121_imagenet'
        ])

    parser.add_argument('--tensorboard',
                        type=str2bool,
                        nargs='?',
                        help='Log progress to TensorBoard')

    parser.add_argument(
        '--save_models',
        default=True,
        type=str2bool,
        nargs='?',
        help='if True, models will be saved to the local folder')

    # ============================PRUNING added
    parser.add_argument(
        '--pruning_config',
        default=None,
        type=str,
        help=
        'path to pruning configuration file, will overwrite all pruning parameters in arguments'
    )

    parser.add_argument('--group_wd_coeff',
                        type=float,
                        default=0.0,
                        help='group weight decay')
    parser.add_argument('--name',
                        default='test',
                        type=str,
                        help='experiment name(folder) to store logs')

    parser.add_argument(
        '--augment',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'enable or not augmentation of training dataset, only for CIFAR, def False'
    )

    parser.add_argument('--load_model',
                        default='',
                        type=str,
                        help='path to model weights')

    parser.add_argument('--pruning',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not pruning, def False')

    parser.add_argument(
        '--pruning-threshold',
        '--pt',
        default=100.0,
        type=float,
        help=
        'Max error perc on validation set while pruning (default: 100.0 means always prune)'
    )

    parser.add_argument(
        '--pruning-momentum',
        default=0.0,
        type=float,
        help=
        'Use momentum on criteria between pruning iterations, def 0.0 means no momentum'
    )

    parser.add_argument('--pruning-step',
                        default=15,
                        type=int,
                        help='How often to check loss and do pruning step')

    parser.add_argument('--prune_per_iteration',
                        default=10,
                        type=int,
                        help='How many neurons to remove at each iteration')

    parser.add_argument(
        '--fixed_layer',
        default=-1,
        type=int,
        help='Prune only a given layer with index, use -1 to prune all')

    parser.add_argument('--start_pruning_after_n_iterations',
                        default=0,
                        type=int,
                        help='from which iteration to start pruning')

    parser.add_argument('--maximum_pruning_iterations',
                        default=1e8,
                        type=int,
                        help='maximum pruning iterations')

    parser.add_argument('--starting_neuron',
                        default=0,
                        type=int,
                        help='starting position for oracle pruning')

    parser.add_argument('--prune_neurons_max',
                        default=-1,
                        type=int,
                        help='prune_neurons_max')

    parser.add_argument('--pruning-method',
                        default=0,
                        type=int,
                        help='pruning method to be used, see readme.md')

    parser.add_argument('--pruning_fixed_criteria',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not criteria reevaluation, def False')

    parser.add_argument('--fixed_network',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='fix network for oracle or criteria computation')

    parser.add_argument(
        '--zero_lr_for_epochs',
        default=-1,
        type=int,
        help='Learning rate will be set to 0 for given number of updates')

    parser.add_argument(
        '--dynamic_network',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'Creates a new network graph from pruned model, works with ResNet-101 only'
    )

    parser.add_argument('--use_test_as_train',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use testing dataset instead of training')

    parser.add_argument('--pruning_mask_from',
                        default='',
                        type=str,
                        help='path to mask file precomputed')

    parser.add_argument(
        '--compute_flops',
        default=True,
        type=str2bool,
        nargs='?',
        help=
        'if True, will run dummy inference of batch 1 before training to get conv sizes'
    )

    # ============================END pruning added

    best_prec1 = 0
    global global_iteration
    global group_wd_optimizer
    global_iteration = 0

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=0)

    device = torch.device("cuda" if use_cuda else "cpu")

    if args.model == "lenet3":
        model = LeNet(dataset=args.dataset)
    elif args.model == "vgg":
        model = vgg11_bn(pretrained=True)
    elif args.model == "resnet18":
        model = PreActResNet18()
    elif (args.model == "resnet50") or (args.model == "resnet50_noskip"):
        if args.dataset == "CIFAR10":
            model = PreActResNet50(dataset=args.dataset)
        else:
            from models.resnet import resnet50
            skip_gate = True
            if "noskip" in args.model:
                skip_gate = False

            if args.pruning_method not in [22, 40]:
                skip_gate = False
            model = resnet50(skip_gate=skip_gate)
    elif args.model == "resnet34":
        if not (args.dataset == "CIFAR10"):
            from models.resnet import resnet34
            model = resnet34()
    elif "resnet101" in args.model:
        if not (args.dataset == "CIFAR10"):
            from models.resnet import resnet101
            if args.dataset == "Imagenet":
                classes = 1000

            if "noskip" in args.model:
                model = resnet101(num_classes=classes, skip_gate=False)
            else:
                model = resnet101(num_classes=classes)

    elif args.model == "resnet20":
        if args.dataset == "CIFAR10":
            raise NotImplementedError(
                "resnet20 is not implemented in the current project")
            # from models.resnet_cifar import resnet20
            # model = resnet20()
    elif args.model == "resnet152":
        model = PreActResNet152()
    elif args.model == "densenet201_imagenet":
        from models.densenet_imagenet import DenseNet201
        model = DenseNet201(gate_types=['output_bn'], pretrained=True)
    elif args.model == "densenet121_imagenet":
        from models.densenet_imagenet import DenseNet121
        model = DenseNet121(gate_types=['output_bn'], pretrained=True)
    else:
        print(args.model, "model is not supported")

    # dataset loading section
    if args.dataset == "MNIST":
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "CIFAR10":
        # Data loading code
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if args.augment:
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        kwargs = {'num_workers': 8, 'pin_memory': True}
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            '../data', train=True, download=True, transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   drop_last=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('../data', train=False, transform=transform_test),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "Imagenet":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        if args.use_test_as_train:
            train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  **kwargs)

    ####end dataset preparation

    if args.dynamic_network:
        # attempts to load pruned model and modify it be removing pruned channels
        # works for resnet101 only
        if (len(args.load_model) > 0) and (args.dynamic_network):
            if os.path.isfile(args.load_model):
                load_model_pytorch(model, args.load_model, args.model)

            else:
                print("=> no checkpoint found at '{}'".format(args.load_model))
                exit()

        dynamic_network_change_local(model)

        # save the model
        log_save_folder = "%s" % args.name
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        if not os.path.exists("%s/models" % (log_save_folder)):
            os.makedirs("%s/models" % (log_save_folder))

        model_save_path = "%s/models/pruned.weights" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint({'state_dict': model_state_dict},
                            False,
                            filename=model_save_path)

    print("model is defined")

    # aux function to get size of feature maps
    # First it adds hooks for each conv layer
    # Then runs inference with 1 image
    output_sizes = get_conv_sizes(args, model)

    if use_cuda and not args.mgpu:
        model = model.to(device)
    elif args.distributed:
        model.cuda()
        print(
            "\n\n WARNING: distributed pruning was not verified and might not work correctly"
        )
        model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.mgpu:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.to(device)

    print(
        "model is set to device: use_cuda {}, args.mgpu {}, agrs.distributed {}"
        .format(use_cuda, args.mgpu, args.distributed))

    weight_decay = args.wd
    if args.fixed_network:
        weight_decay = 0.0

    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    if 1:
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

    total_size_params = sum(
        [np.prod(par.shape) for par in parameters_for_update])
    print("Total number of parameters, w/o usage of bn consts: ",
          total_size_params)

    optimizer = optim.SGD(parameters_for_update,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=weight_decay)

    if 1:
        # helping optimizer to implement group lasso (with very small weight that doesn't affect training)
        # will be used to calculate number of remaining flops and parameters in the network
        group_wd_optimizer = group_lasso_decay(
            parameters_for_update,
            group_lasso_weight=args.group_wd_coeff,
            named_parameters=parameters_for_update_named,
            output_sizes=output_sizes)

    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ###=======================added for pruning
    # logging part
    log_save_folder = "%s" % args.name
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)

    if not os.path.exists("%s/models" % (log_save_folder)):
        os.makedirs("%s/models" % (log_save_folder))

    train_writer = None
    if args.tensorboard:
        try:
            # tensorboardX <= 1.6 uses the log_dir keyword
            train_writer = SummaryWriter(log_dir="%s" % (log_save_folder))
        except TypeError:
            # tensorboardX 1.7 renamed the keyword to logdir
            train_writer = SummaryWriter(logdir="%s" % (log_save_folder))

    time_point = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    textfile = "%s/log_%s.txt" % (log_save_folder, time_point)
    stdout = Logger(textfile)
    sys.stdout = stdout
    print(" ".join(sys.argv))

    # initializing parameters for pruning
    # we can add weights of different layers or we can add gates (multiplies output with 1, useful only for gradient computation)
    pruning_engine = None
    if args.pruning:
        pruning_settings = dict()
        if not (args.pruning_config is None):
            pruning_settings_reader = PruningConfigReader()
            pruning_settings_reader.read_config(args.pruning_config)
            pruning_settings = pruning_settings_reader.get_parameters()

        # overwrite parameters from config file with those from command line
        # needs manual entry here
        # user_specified = [key for key in vars(default_args).keys() if not (vars(default_args)[key]==vars(args)[key])]
        # argv_of_interest = ['pruning_threshold', 'pruning-momentum', 'pruning_step', 'prune_per_iteration',
        #                     'fixed_layer', 'start_pruning_after_n_iterations', 'maximum_pruning_iterations',
        #                     'starting_neuron', 'prune_neurons_max', 'pruning_method']

        has_attribute = lambda x: any([x in a for a in sys.argv])

        if has_attribute('pruning-momentum'):
            pruning_settings['pruning_momentum'] = vars(
                args)['pruning_momentum']
        if has_attribute('pruning-method'):
            pruning_settings['method'] = vars(args)['pruning_method']

        pruning_parameters_list = prepare_pruning_list(
            pruning_settings,
            model,
            model_name=args.model,
            pruning_mask_from=args.pruning_mask_from,
            name=args.name)
        print("Total pruning layers:", len(pruning_parameters_list))

        folder_to_write = "%s" % log_save_folder + "/"
        log_folder = folder_to_write

        pruning_engine = pytorch_pruning(pruning_parameters_list,
                                         pruning_settings=pruning_settings,
                                         log_folder=log_folder)

        pruning_engine.connect_tensorboard(train_writer)
        pruning_engine.dataset = args.dataset
        pruning_engine.model = args.model
        pruning_engine.pruning_mask_from = args.pruning_mask_from
        pruning_engine.load_mask()
        gates_to_params = connect_gates_with_parameters_for_flops(
            args.model, parameters_for_update_named)
        pruning_engine.gates_to_params = gates_to_params

    ###=======================end for pruning
    # loading model file
    if (len(args.load_model) > 0) and (not args.dynamic_network):
        if os.path.isfile(args.load_model):
            load_model_pytorch(model, args.load_model, args.model)
        else:
            print("=> no checkpoint found at '{}'".format(args.load_model))
            exit()

    if args.tensorboard and 0:
        if args.dataset == "CIFAR10":
            dummy_input = torch.rand(1, 3, 32, 32).to(device)
        elif args.dataset == "Imagenet":
            dummy_input = torch.rand(1, 3, 224, 224).to(device)

        train_writer.add_graph(model, dummy_input)

    for epoch in range(1, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(args, optimizer, epoch, args.zero_lr_for_epochs,
                             train_writer)

        if not args.run_test and not args.get_inference_time:
            train(args,
                  model,
                  device,
                  train_loader,
                  optimizer,
                  epoch,
                  criterion,
                  train_writer=train_writer,
                  pruning_engine=pruning_engine)

        if args.pruning:
            # skip validation error calculation and model saving
            if pruning_engine.method == 50: continue

        # evaluate on validation set
        prec1, _ = validate(args,
                            test_loader,
                            model,
                            device,
                            criterion,
                            epoch,
                            train_writer=train_writer)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        model_save_path = "%s/models/checkpoint.weights" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=model_save_path)
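
This script already exposes the relevant knobs on the command line (--lr, --lr-decay-every, --lr-decay-scalar, --zero_lr_for_epochs), so a sketch of the scheduler those arguments imply is given below; the optional TensorBoard logging line is an assumption about how train_writer is used:

def adjust_learning_rate(args, optimizer, epoch, zero_lr_for_epochs, train_writer=None):
    # Step decay: multiply args.lr by args.lr_decay_scalar for every
    # args.lr_decay_every completed epochs; optionally hold the rate at 0
    # for the first `zero_lr_for_epochs` epochs (useful when only pruning).
    lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every))
    if zero_lr_for_epochs > -1 and epoch <= zero_lr_for_epochs:
        lr = 0.0
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    if train_writer is not None:
        train_writer.add_scalar('learning_rate', lr, epoch)
    return lr
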
Exemplo n.º 29
0
def train():
    mkdirs(config.checkpoint_path, config.best_model_path, config.logs)
    # load data
    src1_train_dataloader_fake, src1_train_dataloader_real, \
    src2_train_dataloader_fake, src2_train_dataloader_real, \
    src3_train_dataloader_fake, src3_train_dataloader_real, \
    tgt_valid_dataloader = get_dataset(config.src1_data, config.src1_train_num_frames,
                                       config.src2_data, config.src2_train_num_frames,
                                       config.src3_data, config.src3_train_num_frames,
                                       config.tgt_data, config.tgt_test_num_frames, config.batch_size)

    best_model_ACC = 0.0
    best_model_HTER = 1.0
    best_model_ACER = 1.0
    best_model_AUC = 0.0
    # 0:loss, 1:top-1, 2:EER, 3:HTER, 4:AUC, 5:threshold, 6:ACC_threshold
    valid_args = [np.inf, 0, 0, 0, 0, 0, 0, 0]

    loss_classifier = AverageMeter()
    classifer_top1 = AverageMeter()

    net = DG_model(config.model).to(device)
    ad_net_real = Discriminator().to(device)
    ad_net_fake = Discriminator().to(device)

    log = Logger()
    log.open(config.logs + config.tgt_data + '_log_SSDG.txt', mode='a')
    log.write(
        "\n----------------------------------------------- [START %s] %s\n\n" %
        (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '-' * 51))
    print("Norm_flag: ", config.norm_flag)
    log.write('** start training target model! **\n')
    log.write(
        '--------|------------- VALID -------------|--- classifier ---|------ Current Best ------|--------------|\n'
    )
    log.write(
        '  iter  |   loss   top-1   HTER    AUC    |   loss   top-1   |   top-1   HTER    AUC    |    time      |\n'
    )
    log.write(
        '-------------------------------------------------------------------------------------------------------|\n'
    )
    start = timer()
    criterion = {
        'softmax': nn.CrossEntropyLoss().cuda(),
        'triplet': HardTripletLoss(margin=0.1, hardest=False).cuda()
    }
    optimizer_dict = [
        {
            "params": filter(lambda p: p.requires_grad, net.parameters()),
            "lr": config.init_lr
        },
        {
            "params": filter(lambda p: p.requires_grad,
                             ad_net_real.parameters()),
            "lr": config.init_lr
        },
    ]
    optimizer = optim.SGD(optimizer_dict,
                          lr=config.init_lr,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay)
    init_param_lr = []
    for param_group in optimizer.param_groups:
        init_param_lr.append(param_group["lr"])

    iter_per_epoch = 10

    src1_train_iter_real = iter(src1_train_dataloader_real)
    src1_iter_per_epoch_real = len(src1_train_iter_real)
    src2_train_iter_real = iter(src2_train_dataloader_real)
    src2_iter_per_epoch_real = len(src2_train_iter_real)
    src3_train_iter_real = iter(src3_train_dataloader_real)
    src3_iter_per_epoch_real = len(src3_train_iter_real)
    src1_train_iter_fake = iter(src1_train_dataloader_fake)
    src1_iter_per_epoch_fake = len(src1_train_iter_fake)
    src2_train_iter_fake = iter(src2_train_dataloader_fake)
    src2_iter_per_epoch_fake = len(src2_train_iter_fake)
    src3_train_iter_fake = iter(src3_train_dataloader_fake)
    src3_iter_per_epoch_fake = len(src3_train_iter_fake)

    max_iter = config.max_iter
    epoch = 1
    if (len(config.gpus) > 1):
        net = torch.nn.DataParallel(net).cuda()

    for iter_num in range(max_iter + 1):
        if (iter_num % src1_iter_per_epoch_real == 0):
            src1_train_iter_real = iter(src1_train_dataloader_real)
        if (iter_num % src2_iter_per_epoch_real == 0):
            src2_train_iter_real = iter(src2_train_dataloader_real)
        if (iter_num % src3_iter_per_epoch_real == 0):
            src3_train_iter_real = iter(src3_train_dataloader_real)
        if (iter_num % src1_iter_per_epoch_fake == 0):
            src1_train_iter_fake = iter(src1_train_dataloader_fake)
        if (iter_num % src2_iter_per_epoch_fake == 0):
            src2_train_iter_fake = iter(src2_train_dataloader_fake)
        if (iter_num % src3_iter_per_epoch_fake == 0):
            src3_train_iter_fake = iter(src3_train_dataloader_fake)
        if (iter_num != 0 and iter_num % iter_per_epoch == 0):
            epoch = epoch + 1
        param_lr_tmp = []
        for param_group in optimizer.param_groups:
            param_lr_tmp.append(param_group["lr"])

        net.train(True)
        ad_net_real.train(True)
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, epoch, init_param_lr,
                             config.lr_epoch_1, config.lr_epoch_2)
        ######### data prepare #########
        src1_img_real, src1_label_real = next(src1_train_iter_real)
        src1_img_real = src1_img_real.cuda()
        src1_label_real = src1_label_real.cuda()
        input1_real_shape = src1_img_real.shape[0]

        src2_img_real, src2_label_real = next(src2_train_iter_real)
        src2_img_real = src2_img_real.cuda()
        src2_label_real = src2_label_real.cuda()
        input2_real_shape = src2_img_real.shape[0]

        src3_img_real, src3_label_real = next(src3_train_iter_real)
        src3_img_real = src3_img_real.cuda()
        src3_label_real = src3_label_real.cuda()
        input3_real_shape = src3_img_real.shape[0]

        src1_img_fake, src1_label_fake = next(src1_train_iter_fake)
        src1_img_fake = src1_img_fake.cuda()
        src1_label_fake = src1_label_fake.cuda()
        input1_fake_shape = src1_img_fake.shape[0]

        src2_img_fake, src2_label_fake = next(src2_train_iter_fake)
        src2_img_fake = src2_img_fake.cuda()
        src2_label_fake = src2_label_fake.cuda()
        input2_fake_shape = src2_img_fake.shape[0]

        src3_img_fake, src3_label_fake = next(src3_train_iter_fake)
        src3_img_fake = src3_img_fake.cuda()
        src3_label_fake = src3_label_fake.cuda()
        input3_fake_shape = src3_img_fake.shape[0]

        input_data = torch.cat([
            src1_img_real, src1_img_fake, src2_img_real, src2_img_fake,
            src3_img_real, src3_img_fake
        ],
                               dim=0)

        source_label = torch.cat([
            src1_label_real, src1_label_fake, src2_label_real, src2_label_fake,
            src3_label_real, src3_label_fake
        ],
                                 dim=0)

        ######### forward #########
        classifier_label_out, feature = net(input_data, config.norm_flag)

        ######### single side adversarial learning #########
        input1_shape = input1_real_shape + input1_fake_shape
        input2_shape = input2_real_shape + input2_fake_shape
        feature_real_1 = feature.narrow(0, 0, input1_real_shape)
        feature_real_2 = feature.narrow(0, input1_shape, input2_real_shape)
        feature_real_3 = feature.narrow(0, input1_shape + input2_shape,
                                        input3_real_shape)
        feature_real = torch.cat(
            [feature_real_1, feature_real_2, feature_real_3], dim=0)
        discriminator_out_real = ad_net_real(feature_real)

        ######### unbalanced triplet loss #########
        real_domain_label_1 = torch.LongTensor(input1_real_shape,
                                               1).fill_(0).cuda()
        real_domain_label_2 = torch.LongTensor(input2_real_shape,
                                               1).fill_(0).cuda()
        real_domain_label_3 = torch.LongTensor(input3_real_shape,
                                               1).fill_(0).cuda()
        fake_domain_label_1 = torch.LongTensor(input1_fake_shape,
                                               1).fill_(1).cuda()
        fake_domain_label_2 = torch.LongTensor(input2_fake_shape,
                                               1).fill_(2).cuda()
        fake_domain_label_3 = torch.LongTensor(input3_fake_shape,
                                               1).fill_(3).cuda()
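        # Asymmetric labelling for the triplet loss below: real samples from all three
        # domains share label 0, while fake samples keep distinct per-domain labels
        # (1, 2, 3), so the embedding pulls real faces together across domains but
        # keeps fake faces of different domains apart.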
        source_domain_label = torch.cat([
            real_domain_label_1, fake_domain_label_1, real_domain_label_2,
            fake_domain_label_2, real_domain_label_3, fake_domain_label_3
        ], dim=0).view(-1)
        triplet = criterion["triplet"](feature, source_domain_label)

        ######### adversarial loss on real features #########
        real_shape_list = [input1_real_shape, input2_real_shape, input3_real_shape]
        real_adloss = Real_AdLoss(discriminator_out_real, criterion["softmax"],
                                  real_shape_list)

        ######### cross-entropy loss #########
        # narrowing to input_data.size(0) keeps every prediction; the classifier
        # loss is computed over the full concatenated source batch
        cls_loss = criterion["softmax"](classifier_label_out.narrow(
            0, 0, input_data.size(0)), source_label)

        ######### backward #########
        total_loss = cls_loss + config.lambda_triplet * triplet + config.lambda_adreal * real_adloss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_classifier.update(cls_loss.item())
        acc = accuracy(classifier_label_out.narrow(0, 0, input_data.size(0)),
                       source_label,
                       topk=(1, ))
        classifer_top1.update(acc[0])
        print('\r', end='', flush=True)
        print(
            '  %4.1f  |  %5.3f  %6.3f  %6.3f  %6.3f  |  %6.3f  %6.3f  |  %6.3f  %6.3f  %6.3f  | %s'
            % ((iter_num + 1) / iter_per_epoch, valid_args[0], valid_args[6],
               valid_args[3] * 100, valid_args[4] * 100, loss_classifier.avg,
               classifer_top1.avg, float(best_model_ACC),
               float(best_model_HTER * 100), float(
                   best_model_AUC * 100), time_to_str(timer() - start, 'min')),
            end='',
            flush=True)

        if (iter_num != 0 and (iter_num + 1) % iter_per_epoch == 0):
            # 0:loss, 1:top-1, 2:EER, 3:HTER, 4:AUC, 5:threshold, 6:ACC_threshold
            valid_args = eval(tgt_valid_dataloader, net, config.norm_flag)
            # judge model according to HTER
            is_best = valid_args[3] <= best_model_HTER
            best_model_HTER = min(valid_args[3], best_model_HTER)
            threshold = valid_args[5]
            if (valid_args[3] <= best_model_HTER):
                best_model_ACC = valid_args[6]
                best_model_AUC = valid_args[4]

            save_list = [
                epoch, valid_args, best_model_HTER, best_model_ACC,
                best_model_ACER, threshold
            ]
            save_checkpoint(save_list, is_best, net, config.gpus,
                            config.checkpoint_path, config.best_model_path)
            print('\r', end='', flush=True)
            log.write(
                '  %4.1f  |  %5.3f  %6.3f  %6.3f  %6.3f  |  %6.3f  %6.3f  |  %6.3f  %6.3f  %6.3f  | %s   %s'
                %
                ((iter_num + 1) / iter_per_epoch, valid_args[0], valid_args[6],
                 valid_args[3] * 100, valid_args[4] * 100, loss_classifier.avg,
                 classifer_top1.avg, float(best_model_ACC),
                 float(best_model_HTER * 100), float(best_model_AUC * 100),
                 time_to_str(timer() - start, 'min'), param_lr_tmp[0]))
            log.write('\n')
            time.sleep(0.01)
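
Real_AdLoss is not defined in the listing above, so the helper below is only a
minimal sketch of what such a function typically computes, assuming ad_net_real
already contains a gradient-reversal layer: it builds one domain label per source
for the concatenated real-face features and applies the given cross-entropy
criterion to the discriminator output. The name real_adloss_sketch and its exact
signature are illustrative, not the project's actual API.

import torch

def real_adloss_sketch(discriminator_out, ce_criterion, real_shape_list):
    # real_shape_list holds the number of real samples drawn from each source
    # domain, in the same order they were concatenated into feature_real.
    domain_labels = torch.cat([
        torch.full((n,), d, dtype=torch.long, device=discriminator_out.device)
        for d, n in enumerate(real_shape_list)
    ], dim=0)
    # With gradient reversal inside the discriminator, minimizing this cross-entropy
    # drives the feature extractor toward domain-invariant real-face features.
    return ce_criterion(discriminator_out, domain_labels)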
Example No. 30
0
args.model = 'Bayes' + args.model
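# the 'Bayes' prefix presumably selects the variational (Bayesian) variant of the
# chosen architecture; the model classes themselves are defined elsewhere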

print('Resume training from %s' % fnames[0])
checkpoint = torch.load(fnames[0])
model.load_state_dict(checkpoint, strict=False)

columns = [
    'ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time', 'nll', 'kl',
    'te_ac', 'te_nll'
]

for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()

    lr = schedule(epoch)
    utils.adjust_learning_rate(optimizer1, lr)
    optvar = optimizer2
    train_res = utils.train_epoch_vi(
        loaders['train'],
        model,
        criterion,
        optimizer1,
        args.beta,  # if epoch > args.epochs/2 else 0,
        optvar)
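    # train_epoch_vi presumably optimizes a KL-weighted variational objective
    # (roughly NLL + args.beta * KL, matching the 'nll' and 'kl' columns declared
    # above); the commented-out schedule would keep the KL term switched off for
    # the first half of training.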
    test_res = utils.eval(loaders['test'], model, criterion)

    time_ep = time.time() - time_ep
    values = [
        epoch + 1, lr, train_res['loss'], train_res['accuracy'],