예제 #1
0
def main():
    """Evaluate every experiment under ``runs/`` and print a results table.

    For each experiment directory, loads the best checkpoint and its JSON
    hyper-parameters, rebuilds the matching model, runs the test loop, and
    prints loss/accuracy with margins of error. Experiments missing either
    file are collected and reported at the end instead of aborting the sweep.

    Relies on module-level ``args``, ``best_acc1`` and ``device`` being
    initialized by the surrounding script.
    """
    global args, best_acc1, device

    criterion = PrototypicalLoss().to(device)

    cudnn.benchmark = True

    runs_path = glob('runs/*')
    # default= keeps this from crashing when runs/ is empty
    max_len_exp = max((len(x) for x in runs_path), default=len('Experiment')) + 2
    print(f"|{'Experiment':^{max_len_exp}}|{'Loss':^17}|{'ACC':^17}|")

    except_list = []
    pl_mi = u"\u00B1"  # plus-minus sign for "value ± margin" strings

    # reuse the glob result computed above instead of re-scanning the directory
    for exp in runs_path:
        checkpoint, args = None, None
        files = glob(exp + '/*')
        for file in files:
            if file.endswith('model_best.pth'):
                checkpoint = torch.load(os.path.abspath(file))
            elif file.endswith('.json'):
                # close the params file deterministically instead of leaking it
                with open(os.path.abspath(file)) as f:
                    params = json.load(f)
                args = SimpleNamespace(**params)

        if checkpoint is None or args is None:
            except_list.append(f"checkpoint and params do not exist in {exp}")
            continue

        # omniglot models were trained on 'trainval', so evaluate on 'test';
        # other datasets hold out 'val'
        if args.dataset == 'omniglot':
            test_loader = get_dataloader(args, 'test')
        else:
            test_loader = get_dataloader(args, 'val')

        input_dim = 1 if args.dataset == 'omniglot' else 3
        if args.model == 'protonet':
            model = ProtoNet(input_dim).to(device)
        else:
            model = ResNet(input_dim).to(device)

        model.load_state_dict(checkpoint['model_state_dict'])
        best_acc1 = checkpoint['best_acc1']

        loss_list, acc_list = test(test_loader, model, criterion)

        loss, loss_moe = margin_of_error(loss_list)
        acc, acc_moe = margin_of_error(acc_list)

        loss_string = f'{loss:.3f} {pl_mi} {loss_moe:.3f}'
        acc_string = f'{acc:.3f} {pl_mi} {acc_moe:.3f}'

        # width 17 matches the header columns so the table stays aligned
        print(f"|{exp:^{max_len_exp}}|{loss_string:^17}|{acc_string:^17}|")

    if len(except_list):
        pp(except_list)
예제 #2
0
        # load pre-trained model (no FC weights)
    model_dict = model.state_dict()
    if args.init_weights is not None:
        # the checkpoint stores its weights under the 'params' key
        pretrained_dict = torch.load(args.init_weights)['params']
        # remove weights for FC
        # checkpoint keys lack the 'encoder.' prefix used by this model,
        # so re-prefix them before matching against model_dict
        pretrained_dict = {
            'encoder.' + k: v
            for k, v in pretrained_dict.items()
        }
        # keep only keys present in this model (drops anything else, e.g. an FC head)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        print(pretrained_dict.keys())
        model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    if torch.cuda.is_available():
        # let cuDNN pick the fastest kernels, then move the model to GPU
        torch.backends.cudnn.benchmark = True
        model = model.cuda()

    def save_model(name):
        """Save the current model weights to ``<save_path>/<name>.pth``."""
        state = dict(params=model.state_dict())
        target = osp.join(args.save_path, name + '.pth')
        torch.save(state, target)

    # training log: hyper-parameters plus per-epoch loss/accuracy curves
    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    if args.lr_decay:
        # step-wise decay: multiply LR by gamma every step_size epochs
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma
        )

        # load pre-trained model (no FC weights)
    model_dict = model.state_dict()
    if args.init_weights is not None:
        # the checkpoint stores its weights under the "params" key
        pretrained_dict = torch.load(args.init_weights)["params"]
        # remove weights for FC
        # re-prefix checkpoint keys with "encoder." to match this model's state dict
        pretrained_dict = {"encoder." + k: v for k, v in pretrained_dict.items()}
        # keep only keys present in this model (drops anything else, e.g. an FC head)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        print(pretrained_dict.keys())
        model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    if torch.cuda.is_available():
        # let cuDNN pick the fastest kernels, then move the model to GPU
        torch.backends.cudnn.benchmark = True
        model = model.cuda()
    # model = nn.DataParallel(model)

    def save_model(name):
        """Write the model's state dict to ``<save_path>/<name>.pth``."""
        payload = {"params": model.state_dict()}
        torch.save(payload, osp.join(args.save_path, name + ".pth"))

    # training log: hyper-parameters plus per-epoch loss curves
    trlog = {}
    trlog["args"] = vars(args)
    trlog["train_loss"] = []
    trlog["val_loss"] = []
예제 #4
0
        l3_n_infer=args.l3_n_infer,
    )

    model = model.cuda()

    # save_iter == -1 means "use the best checkpoint";
    # otherwise load the checkpoint from that specific iteration
    if args.save_iter != -1:
        modelfile = get_assigned_file(args.checkpoint_dir, args.save_iter)
    else:
        modelfile = get_best_file(args.checkpoint_dir)

    if modelfile is not None:
        tmp = torch.load(modelfile)
        model.load_state_dict(
            tmp["state"],
            # If language was used for pretraining, ignore
            # the language model component here. If we want to use language,
            # make sure the model is loaded
            strict=args.lsl,
        )

        if args.save_embeddings:
            if args.lsl:
                # dump learned word embeddings plus the vocab, ordered by index,
                # then exit — this run is export-only
                weights = model.language_model.embedding.weight.detach().cpu().numpy()
                # sort (token, index) pairs by index so row i matches token i
                vocab_srt = sorted(list(vocab.items()), key=lambda x: x[1])
                vocab_srt = [v[0] for v in vocab_srt]
                with open(args.embeddings_file, "w") as fout:
                    fout.write("\n".join(vocab_srt))
                    fout.write("\n")
                np.savetxt(args.embeddings_metadata, weights, fmt="%f", delimiter="\t")
                sys.exit(0)
예제 #5
0
def main():
    """Train a few-shot model, validating periodically and checkpointing the best.

    Relies on module-level ``args`` (hyper-parameters), ``best_acc1`` (running
    best validation accuracy — assumed initialized by the surrounding script),
    ``device`` and ``writer`` (TensorBoard) being set up before this is called.
    """
    global args, best_acc1, device

    # Init seed for reproducibility across numpy and torch (CPU + CUDA)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)

    # omniglot: grayscale (1 channel), train on train+val, validate on test;
    # other datasets: RGB (3 channels) with a conventional train/val split
    if args.dataset == 'omniglot':
        train_loader, val_loader = get_dataloader(args, 'trainval', 'test')
        input_dim = 1
    else:
        train_loader, val_loader = get_dataloader(args, 'train', 'val')
        input_dim = 3

    if args.model == 'protonet':
        model = ProtoNet(input_dim).to(device)
        print("ProtoNet loaded")
    else:
        model = ResNet(input_dim).to(device)
        print("ResNet loaded")

    criterion = PrototypicalLoss().to(device)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    cudnn.benchmark = True

    if args.resume:
        try:
            # pick the longest checkpoint filename — presumably the latest
            # epoch (sorting by length orders numeric suffixes); TODO confirm
            checkpoint = torch.load(
                sorted(glob(f'{args.log_dir}/checkpoint_*.pth'), key=len)[-1])
        except Exception:
            # no periodic checkpoints found; fall back to the best model
            checkpoint = torch.load(args.log_dir + '/model_best.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']

        print(f"load checkpoint {args.exp_name}")
    else:
        start_epoch = 1

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        gamma=args.lr_scheduler_gamma,
        step_size=args.lr_scheduler_step)

    print(
        f"model parameter : {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    for epoch in range(start_epoch, args.epochs + 1):

        train_loss = train(train_loader, model, optimizer, criterion, epoch)

        # validate every test_iter epochs, plus always on the first and last
        is_test = epoch % args.test_iter == 0
        if is_test or epoch == args.epochs or epoch == 1:

            val_loss, acc1 = validate(val_loader, model, criterion, epoch)

            # ties count as a new best, matching the original >= comparison
            is_best = acc1 >= best_acc1
            if is_best:
                best_acc1 = acc1

            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer_state_dict': optimizer.state_dict(),
                }, is_best, args)

            if is_best:
                writer.add_scalar("BestAcc", acc1, epoch)

            print(
                f"[{epoch}/{args.epochs}] {train_loss:.3f}, {val_loss:.3f}, {acc1:.3f}, # {best_acc1:.3f}"
            )

        else:
            print(f"[{epoch}/{args.epochs}] {train_loss:.3f}")

        scheduler.step()

    writer.close()
예제 #6
0
        lang_supervision=args.lang_supervision,
        l3=args.l3,
        l3_model=l3_model,
        l3_n_infer=args.l3_n_infer)

    model = model.cuda()

    # make sure the checkpoint directory exists before training writes to it
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    start_epoch = args.start_epoch
    stop_epoch = args.stop_epoch

    if args.resume:
        # resume from the most recent checkpoint in the directory, if any
        resume_file = get_resume_file(args.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            # continue from the epoch after the one that was saved
            start_epoch = tmp["epoch"] + 1
            model.load_state_dict(tmp["state"])

    # per-run metrics file name, keyed by the run index args.n
    metrics_fname = "metrics_{}.json".format(args.n)

    train(
        base_loader,
        val_loader,
        model,
        start_epoch,
        stop_epoch,
        args,
        metrics_fname=metrics_fname,
    )