Exemplo n.º 1
0
def load_model(model_name: str):
    """Build a model and load its saved weights from the checkpoint directory.

    The architecture name handed to ``get_res_pre_trained`` is the part of
    *model_name* before the first underscore. The weights are read from
    ``/opt/ml/code/checkpoint/last/<model_name>``.

    Args:
        model_name: Name of the checkpoint file; its prefix up to the first
            ``_`` selects the architecture.

    Returns:
        The model produced by ``get_res_pre_trained`` with the checkpoint's
        state dict loaded into it.
    """
    base = model_name.split("_")[0]
    model = get_res_pre_trained(base)
    save_path = f"/opt/ml/code/checkpoint/last/{model_name}"

    # map_location="cpu" lets a checkpoint that was saved from a GPU be read
    # on a host without CUDA. load_state_dict copies the values into the
    # model's existing parameters, so the model's own device is unchanged.
    state_dict = torch.load(save_path, map_location="cpu")
    model.load_state_dict(state_dict)
    print(f"Loaded model:{model_name}")

    return model
Exemplo n.º 2
0
def main(args):
    """Set up a wandb run, build the training objects and start training.

    The run is registered under the ``stage-1`` project and named after
    ``args.MODEL``. After ``args`` has been pushed into ``wandb.config``,
    every remaining setting (``BATCH_SIZE``, ``MODEL``, ``device``, ``LOSS``,
    ``LEARNING_RATE``) is read back from that config object, which is also
    what gets passed on to ``run``.

    Args:
        args: Settings object; must expose ``MODEL`` and be accepted by
            ``wandb.config.update``.
    """
    wandb.init(project="stage-1", reinit=True)
    wandb.run.name = args.MODEL
    wandb.config.update(args)

    # From here on the wandb config is the single source of settings.
    config = wandb.config

    train_batches, val_batches = get_loader(config.BATCH_SIZE)
    print("Get loader")

    network = get_res_pre_trained(config.MODEL)
    network = network.to(config.device)
    print("Load model")

    wandb.watch(network)

    loss_fn = create_criterion(config.LOSS)
    adam = optim.Adam(network.parameters(), lr=config.LEARNING_RATE)

    print("Run")
    run(config, network, loss_fn, adam, train_batches, val_batches)
def load_model(args):
    """Build the model named by ``args.MODEL`` and load weights from disk.

    Args:
        args: Settings object exposing ``MODEL`` (architecture name passed to
            ``get_res_pre_trained``), ``MODEL_PATH`` (checkpoint file) and
            ``device`` (target device for the returned model).

    Returns:
        The model with the checkpoint's state dict loaded, moved to
        ``args.device``.
    """
    model = get_res_pre_trained(model_name=args.MODEL)

    # Load the checkpoint onto the CPU first so that a file saved from a GPU
    # can still be read on a host without CUDA; the model is moved to the
    # requested device afterwards.
    state_dict = torch.load(args.MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)

    model = model.to(args.device)
    return model
Exemplo n.º 4
0
            )
            # print(f'epoch:[{epoch+1}/{EPOCHS}] loss:[{loss_val_avg:.3f}]')
            # print(loss_val_sum)
        if (loss_val_sum < pre_loss) or epoch == (EPOCHS - 1):
            save_model(model, cfg.MODEL, epoch, loss_val_avg)

    print("Training Done !")


# Script entry point: report the environment, then build the data loaders,
# model, loss and optimizer from the `cfg` settings and start training.
if __name__ == "__main__":
    print("PyTorch version:[%s]." % (torch.__version__))
    print("This code use [%s]." % (cfg.device))

    train_loader, eval_loader = get_loader(cfg.BATCH_SIZE)

    # Debug aid: print(cfg.EPOCHS)
    # Previously observed error:
    #   "num_samples should be a positive integer value, but got num_samples=0"
    # (presumably raised when a loader's dataset is empty — TODO confirm).

    model = get_res_pre_trained(cfg.MODEL).to(cfg.device)
    # Debug aid: uncomment to inspect the model and one training batch.
    # print(model)
    # images, labels = next(iter(train_loader))
    # print(f'images shape: {images.shape}')
    # print(f'labels shape: {labels.shape}')
    # print(images)
    # print(labels)

    # NOTE(review): criterion and optimizer are bound at module level but are
    # not passed to train_model below; presumably train_model reads them as
    # globals — verify before renaming or removing them.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)

    train_model(model, train_loader, eval_loader, cfg.EPOCHS, cfg.device)