예제 #1
0
def main():
  """Parse training options, build the model and trainer, and run training.

  Relies on names that are not defined inside this function (``device``,
  ``train``, ``TrainParser``, ``Trainer``, ``AdversarialTrainer``, ``utils``,
  ``models``, ``datasets``); they must be available at module level.
  """
  opt = TrainParser().parse()

  # Configure logging for this run.
  utils.set_logger(opt=opt, filename='train.log', filemode='w')

  # Only a truthy seed triggers the deterministic setup.
  if opt.seed:
    utils.make_deterministic(opt.seed)

  # Build the network, move it to the target device, and optionally wrap it
  # for multi-GPU execution.
  model = models.get_model(opt).to(device)
  if opt.multi_gpu and device == 'cuda':
    model = torch.nn.DataParallel(model)

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = utils.get_optimizer(
    opt,
    params=model.parameters(),
    lr=opt.lr,
    weight_decay=opt.weight_decay,
    momentum=opt.momentum,
  )

  # Pick the trainer class first, then instantiate it.
  if opt.adversarial:
    trainer_cls = AdversarialTrainer
  else:
    trainer_cls = Trainer
  trainer = trainer_cls(
    opt=opt,
    model=model,
    optimizer=optimizer,
    val_metric_name='acc (%)',
    val_metric_obj='max',
  )

  # Models exposing `update_centers` get a hook on the trainer that refreshes
  # the centers through utils.update_centers_eval.
  if hasattr(model, 'update_centers'):
    def _refresh_centers():
      return utils.update_centers_eval(model)
    trainer.update_centers_eval = _refresh_centers

  loader = datasets.get_dataloaders(opt)

  # TODO a hacky way to load some dummy validation data: flip `is_train` off
  # while requesting the loaders, then restore it.
  opt.is_train = False
  val_loader = datasets.get_dataloaders(opt)
  opt.is_train = True

  if opt.load_model:
    trainer.load()

  # Persist the initial model, labelled with the epoch preceding the first
  # training epoch.
  initial_epoch = trainer.start_epoch - 1
  trainer.save(
    epoch=initial_epoch,
    val_metric_value=trainer.best_val_metric,
    force_save=True,
  )

  utils.update_centers_eval(model)
  train(
    opt,
    n_epochs=opt.n_epochs,
    trainer=trainer,
    loader=loader,
    val_loader=val_loader,
    criterion=criterion,
    device=device,
  )
예제 #2
0
파일: visualize.py 프로젝트: skn123/kerNET
    fig.savefig(save_name, bbox_inches='tight')


if __name__ == '__main__':
    # Script entry point: load a trained network from `opt.checkpoint_dir`,
    # run it over the data returned by `datasets.get_dataloaders`, and collect
    # raw inputs, labels and activations via `test`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    opt = TestParser().parse()

    # set up logger
    utils.set_logger(opt=opt, filename='test.log', filemode='a')
    logger = logging.getLogger()

    if opt.adversarial:
        # Imported only when adversarial mode is requested.
        from cleverhans.future.torch.attacks import \
            fast_gradient_method, projected_gradient_descent

    net = models.get_model(opt)
    net = net.to(device)
    if opt.multi_gpu and device == 'cuda':
        net = torch.nn.DataParallel(net)

    # `map_location=device` remaps the stored tensors onto the device chosen
    # above, so a checkpoint written from CUDA tensors can still be loaded on
    # a CPU-only machine instead of failing during deserialization.
    checkpoint = torch.load(
        os.path.join(opt.checkpoint_dir, 'net.pth'),
        map_location=device)
    net.load_state_dict(checkpoint['state_dict'])

    if hasattr(net, 'update_centers'):
        utils.update_centers_eval(net)

    # Every child module of `net` except the last one, chained sequentially.
    net_head = torch.nn.Sequential(*list(net.children())[:-1])

    loader = datasets.get_dataloaders(opt)
    raw_data, labels, activations = test(opt, net, loader)

    # Random ordering of the indices 0..9999.
    # NOTE(review): 10000 presumably matches the number of evaluated
    # samples -- confirm against the dataset in use.
    idx = np.random.permutation(10000)
예제 #3
0
def main():
    """Train a network part by part: hidden modules first, output layer last.

    The model is split into ``opt.n_parts`` modules. Parts 1 to
    ``n_parts - 1`` are trained with the hidden objective named by
    ``opt.hidden_objective``; the final part is trained with cross-entropy.
    Each part has its own optimizer, trainer and per-part hyper-parameters
    (``lr<i>``, ``weight_decay<i>``, ``momentum<i>``, ``n_epochs<i>``).

    Relies on names that are not defined inside this function (``device``,
    ``train_hidden``, ``train_output``, ``losses``, ``utils``, ``models``,
    ``datasets``, the parser and trainer classes); they must be available at
    module level.
    """

    def _save_initial(part_trainer, model_name):
        # Write the checkpoint of a part before it is trained, labelled with
        # the epoch preceding that trainer's first epoch.
        part_trainer.save(epoch=part_trainer.start_epoch - 1,
                          val_metric_value=part_trainer.best_val_metric,
                          model_name=model_name,
                          force_save=True)

    opt = TrainParser().parse()

    # Configure logging for this run.
    utils.set_logger(opt=opt, filename='train.log', filemode='w')

    # Only a truthy seed triggers the deterministic setup.
    if opt.seed:
        utils.make_deterministic(opt.seed)

    loader = datasets.get_dataloaders(opt)

    # TODO a hacky way to load some dummy validation data: flip `is_train`
    # off while requesting the loaders, then restore it.
    opt.is_train = False
    val_loader = datasets.get_dataloaders(opt)
    opt.is_train = True

    model = models.get_model(opt).to(device)
    modules, params = model.split(n_parts=opt.n_parts, mode=opt.split_mode)

    if opt.adversarial:
        trainer_cls = AdversarialTrainer
    else:
        trainer_cls = Trainer

    # The hidden objective is constructed from the `phi` attribute of the
    # model's last child module and the number of classes.
    output_layer = list(model.children())[-1]
    hidden_loss_cls = getattr(losses, opt.hidden_objective)
    hidden_criterion = hidden_loss_cls(output_layer.phi, opt.n_classes)
    output_criterion = torch.nn.CrossEntropyLoss()

    # One optimizer and one trainer per part; part ids are 1-based.
    optimizers, trainers = [], []
    for part_id in range(1, opt.n_parts + 1):
        slot = part_id - 1
        is_hidden_part = part_id < opt.n_parts

        optimizer = utils.get_optimizer(
            opt,
            params=params[slot],
            lr=getattr(opt, 'lr{}'.format(part_id)),
            weight_decay=getattr(opt, 'weight_decay{}'.format(part_id)),
            momentum=getattr(opt, 'momentum{}'.format(part_id)))
        optimizers.append(optimizer)

        # The module just before this one is handed over as `set_eval`;
        # the first part has no predecessor.
        trainer = trainer_cls(
            opt=opt,
            model=modules[slot],
            set_eval=modules[slot - 1] if part_id > 1 else None,
            optimizer=optimizer,
            val_metric_name=opt.hidden_objective if is_hidden_part else 'acc',
            val_metric_obj='max')
        trainers.append(trainer)

        if opt.load_model:
            # Hidden parts load 'net_part<i>.pth'; the output part 'net.pth'.
            if is_hidden_part:
                trainer.load('net_part{}.pth'.format(part_id))
            else:
                trainer.load('net.pth')

    # ---- first hidden module -------------------------------------------
    first_trainer = trainers[0]
    _save_initial(first_trainer, 'net_part{}.pth'.format(1))
    train_hidden(opt,
                 n_epochs=opt.n_epochs1,
                 trainer=first_trainer,
                 loader=loader,
                 val_loader=val_loader,
                 criterion=hidden_criterion,
                 part_id=1,
                 device=device)

    # ---- remaining hidden modules --------------------------------------
    for part_id in range(2, opt.n_parts):
        hidden_trainer = trainers[part_id - 1]
        _save_initial(hidden_trainer, 'net_part{}.pth'.format(part_id))
        # prepare centers
        utils.update_centers_eval(model)
        # exclude the preceding network part from the graph to make things
        # faster
        utils.exclude_during_backward(modules[part_id - 2])
        train_hidden(opt,
                     n_epochs=getattr(opt, 'n_epochs{}'.format(part_id)),
                     trainer=hidden_trainer,
                     loader=loader,
                     val_loader=val_loader,
                     criterion=hidden_criterion,
                     part_id=part_id,
                     device=device)

    # ---- output layer ---------------------------------------------------
    output_trainer = trainers[-1]
    _save_initial(output_trainer, 'net.pth')
    utils.update_centers_eval(model)
    # The second-to-last module is excluded from the backward pass only for
    # the duration of output-layer training, then included again.
    penultimate = modules[-2]
    utils.exclude_during_backward(penultimate)
    train_output(opt,
                 n_epochs=getattr(opt, 'n_epochs{}'.format(opt.n_parts)),
                 trainer=output_trainer,
                 loader=loader,
                 val_loader=val_loader,
                 criterion=output_criterion,
                 part_id=opt.n_parts,
                 device=device)
    utils.include_during_backward(penultimate)