Example #1
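This snippet builds one of three few-shot models from args.model_type (ProtoNet, HypNet, or ProtoNetWithHyperbolic), sets up an Adam optimizer with optional step-decay scheduling, and then loads pre-trained encoder weights while skipping any checkpoint entries, such as the FC head, that do not exist in the target model.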
import torch

# Select the model by type; ProtoNet, HypNet, and ProtoNetWithHyperbolic
# are assumed to be imported from the surrounding project.
model_type = args.model_type.lower()
if model_type == 'protonet':
    model = ProtoNet(args)
elif model_type == 'hypnet':
    model = HypNet(args)
elif model_type == 'protonetwithhyperbolic':
    model = ProtoNetWithHyperbolic(args)
else:
    raise ValueError(f'unknown model type: {args.model_type}')

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

if args.lr_decay:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=args.gamma)

# Load pre-trained weights, skipping any keys (e.g. the FC head) that the
# current model does not have.
model_dict = model.state_dict()
if args.init_weights is not None:
    pretrained_dict = torch.load(args.init_weights)['params']
    # The checkpoint stores bare backbone keys; prefix them so they match
    # the 'encoder.' submodule of the few-shot model.
    pretrained_dict = {
        'encoder.' + k: v
        for k, v in pretrained_dict.items()
    }
    # Keep only keys present in the target model; this drops the FC weights.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    print(pretrained_dict.keys())
    model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
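The key pattern above is partial state-dict loading. A minimal, self-contained sketch of the same prefix-and-filter steps, using toy modules and random tensors rather than the project's real classes:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(8, 4)  # stands in for a conv backbone

class FewShotModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()

model = FewShotModel()
model_dict = model.state_dict()  # keys: 'encoder.conv.weight', 'encoder.conv.bias'

# A fake pre-training checkpoint: backbone weights plus a classifier head.
pretrained_dict = {
    'conv.weight': torch.randn(4, 8),
    'conv.bias': torch.randn(4),
    'fc.weight': torch.randn(10, 4),  # 'encoder.fc.weight' is not in model_dict
}

# Prefix to match the wrapper module, then keep only keys the model has.
pretrained_dict = {'encoder.' + k: v for k, v in pretrained_dict.items()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print(sorted(pretrained_dict))  # ['encoder.conv.bias', 'encoder.conv.weight']

The FC head is dropped purely by the membership test against model_dict, so the same code works for any head shape mismatch.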
Example #2
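This main() wires up a full prototypical-network training run: seeded data loaders (Omniglot or an RGB dataset), model and loss construction, optional checkpoint resume, a StepLR schedule, and a train/validate loop that tracks the best top-1 accuracy.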
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from glob import glob


def main():
    # args, best_acc1, device, and the TensorBoard writer are module-level
    # globals set up by the surrounding script
    global args, best_acc1, device

    # Seed NumPy and PyTorch (CPU and GPU) for reproducibility
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)

    if args.dataset == 'omniglot':
        # Omniglot is grayscale; train on 'trainval' and evaluate on 'test'
        train_loader, val_loader = get_dataloader(args, 'trainval', 'test')
        input_dim = 1
    else:
        # RGB datasets use the standard train/val split
        train_loader, val_loader = get_dataloader(args, 'train', 'val')
        input_dim = 3

    if args.model == 'protonet':
        model = ProtoNet(input_dim).to(device)
        print("ProtoNet loaded")
    else:
        model = ResNet(input_dim).to(device)
        print("ResNet loaded")

    criterion = PrototypicalLoss().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # let cuDNN pick the fastest kernels for the fixed input size
    cudnn.benchmark = True

    if args.resume:
        try:
            # pick the most recent epoch checkpoint (sorting by filename
            # length puts larger epoch numbers last)
            checkpoint = torch.load(
                sorted(glob(f'{args.log_dir}/checkpoint_*.pth'), key=len)[-1])
        except Exception:
            checkpoint = torch.load(args.log_dir + '/model_best.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # the saved epoch already finished, so resume from the next one
        start_epoch = checkpoint['epoch'] + 1
        best_acc1 = checkpoint['best_acc1']

        print(f"loaded checkpoint {args.exp_name}")
    else:
        start_epoch = 1

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        gamma=args.lr_scheduler_gamma,
        step_size=args.lr_scheduler_step)

    print(
        f"trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    for epoch in range(start_epoch, args.epochs + 1):

        train_loss = train(train_loader, model, optimizer, criterion, epoch)

        # validate every test_iter epochs, plus the first and last epoch
        is_test = epoch % args.test_iter == 0
        if is_test or epoch == args.epochs or epoch == 1:

            val_loss, acc1 = validate(val_loader, model, criterion, epoch)

            # track the best validation accuracy seen so far
            is_best = acc1 >= best_acc1
            best_acc1 = max(acc1, best_acc1)

            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer_state_dict': optimizer.state_dict(),
                }, is_best, args)

            if is_best:
                writer.add_scalar("BestAcc", acc1, epoch)

            print(
                f"[{epoch}/{args.epochs}] train {train_loss:.3f}, val {val_loss:.3f}, "
                f"acc {acc1:.3f}, best {best_acc1:.3f}"
            )

        else:
            print(f"[{epoch}/{args.epochs}] train {train_loss:.3f}")

        scheduler.step()

    writer.close()
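save_checkpoint is not shown above. A minimal sketch consistent with what main() expects, namely per-epoch checkpoint_{epoch}.pth files under args.log_dir plus a model_best.pth copy that the resume branch falls back to, could look like this (the helper body is an assumption, not the project's actual code):

import shutil
import torch

def save_checkpoint(state, is_best, args):
    # Hypothetical helper: write one checkpoint per epoch, and mirror the
    # best-scoring one to model_best.pth for the resume fallback.
    path = f"{args.log_dir}/checkpoint_{state['epoch']}.pth"
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, f"{args.log_dir}/model_best.pth")

Keeping one file per epoch is what lets the length-based sort in the resume branch locate the latest checkpoint.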