Exemplo n.º 1
0
def main(**kwargs):
    """Train, validate and test the model selected by ``opt.model_name``.

    Each keyword argument overrides the attribute of the same name on a fresh
    ``Config()``.  If ``opt.model_path`` is set, model weights, optimizer
    state, epoch and learning rate are first restored from that checkpoint.

    Every epoch the model is trained on the training split and evaluated on
    the validation and test splits.  Loss and AUC values are printed, and
    plotted through ``vis.line`` into the ``"loss"`` / ``"auc"`` windows when
    ``opt.vis`` is set.  The loss lists returned by the train/valid/test
    routines are written with ``myutils.save_loss_file``.  A checkpoint is
    saved every ``opt.save_every`` epochs, and a final checkpoint is saved
    when training ends.
    """
    opt = Config()
    for key, value in kwargs.items():
        setattr(opt, key, value)

    vis = Visualizer(opt.env) if opt.vis else None

    init_loss_file(opt)
    train_path, valid_path, test_path = init_file_path(opt)

    train_dataset = KTData(train_path, opt='None')
    valid_dataset = KTData(valid_path, opt='None')
    test_dataset = KTData(test_path, opt='None')
    print(len(train_dataset), len(valid_dataset), len(test_dataset))

    # All three splits share identical loader settings.  drop_last=True is
    # presumably required because the models are built with a fixed
    # batch_size -- TODO confirm before changing it.
    train_loader, valid_loader, test_loader = (
        DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                   num_workers=opt.num_workers, drop_last=True,
                   collate_fn=myutils.collate_fn)
        for dataset in (train_dataset, valid_dataset, test_dataset)
    )

    # The architectures share one constructor signature; any model_name other
    # than "CNN" / "CNN_3D" falls back to RNN_DKT.
    if opt.model_name == "CNN":
        model_cls = CNN
    elif opt.model_name == "CNN_3D":
        model_cls = CNN_3D
    else:
        model_cls = RNN_DKT
    model = model_cls(opt.input_dim, opt.embed_dim, opt.hidden_dim,
                      opt.num_layers, opt.output_dim, opt.batch_size,
                      opt.device)
    # Move the model BEFORE building the optimizer and restoring its state.
    # Optimizer.load_state_dict casts state tensors to the device of their
    # parameters at load time.  Loading while the parameters are still on CPU
    # would therefore leave Adam's moment buffers on CPU after a later move
    # to a GPU, which breaks optimizer.step().
    model = model.to(opt.device)

    lr = opt.lr
    last_epoch = -1
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=lr,
        weight_decay=opt.weight_decay,
        betas=(0.9, 0.99),
    )
    if opt.model_path:
        # Deserialize every tensor onto CPU.  load_state_dict then copies the
        # weights / optimizer state onto the model's device.
        checkpoint = torch.load(opt.model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        last_epoch = checkpoint["epoch"]
        lr = checkpoint["lr"]
        optimizer.load_state_dict(checkpoint["optimizer"])

    # Keeps `epoch` bound for the final save below even if the loop body
    # never runs (e.g. opt.max_epoch == 0).
    epoch = last_epoch
    for epoch in range(opt.max_epoch):
        # NOTE(review): on resume, epoch == last_epoch is trained again.  If
        # the checkpoint's "epoch" is the last *completed* epoch, this should
        # probably be `<=` -- confirm against myutils.save_model_weight.
        if epoch < last_epoch:
            continue

        # NOTE(review): every model type, not only CNN_3D, goes through the
        # *_3d routines -- confirm this is intended.
        train_loss_meter, train_auc_meter, train_loss_list = train.train_3d(
            opt, vis, model, train_loader, epoch, lr, optimizer)
        val_loss_meter, val_auc_meter, val_loss_list = train.valid_3d(
            opt, vis, model, valid_loader, epoch)
        test_loss_meter, test_auc_meter, test_loss_list = test.test_3d(
            opt, vis, model, test_loader, epoch)

        # Plot window name -> {series name: value}.  Only element [0] of each
        # meter's value() is reported.
        metrics = {
            "loss": {
                "train_loss": train_loss_meter.value()[0],
                "val_loss": val_loss_meter.value()[0],
                "test_loss": test_loss_meter.value()[0],
            },
            "auc": {
                "train_auc": train_auc_meter.value()[0],
                "val_auc": val_auc_meter.value()[0],
                "test_auc": test_auc_meter.value()[0],
            },
        }
        for win, series in metrics.items():
            for k, v in series.items():
                print("epoch:{epoch}, {k}:{v:.5f}".format(epoch=epoch, k=k, v=v))
                if vis is not None:
                    vis.line(X=np.array([epoch]), Y=np.array([v]),
                             win=win,
                             opts=dict(title=win, showlegend=True),
                             name=k,
                             update='append')

        # Write this epoch's loss lists to the loss file.
        myutils.save_loss_file(opt, epoch, train_loss_list, val_loss_list,
                               test_loss_list)

        # Save model + optimizer parameters every opt.save_every epochs.
        if epoch % opt.save_every == 0:
            myutils.save_model_weight(opt, model, optimizer, epoch, lr)

        # Learning-rate decay; adjust_lr returns the lr used from now on.
        lr = myutils.adjust_lr(opt, optimizer, epoch)

    # Save the final model parameters when training ends.
    myutils.save_model_weight(opt, model, optimizer, epoch, lr, is_final=True)
Exemplo n.º 2
0
# NOTE(review): script fragment -- total1..total4, device, CNN and nn are
# defined outside this view, and the loop below appears to continue past it.
totals = [total1, total2, total3, total4]

# Length of each entry in `totals`, set manually (not read from the data)
lengths = [9000, 9000, 9000, 9000]

# Ground-truth class label for each entry in `totals`, set manually
True_labels = [3, 2, 0, 0]

# Threshold, under which we concat
# NOTE(review): not used in the visible lines; its use is in code not shown.
Threshold = 0.001

# 4-class CNN wrapped in DataParallel before loading the checkpoint,
# presumably so the saved state_dict keys ("module."-prefixed) match --
# TODO confirm.  map_location keeps the loaded tensors on CPU; load_state_dict
# then copies them onto `device`.
model = CNN(num_classes=4)
model = nn.DataParallel(model)
model = model.to(device)
model.load_state_dict(
    torch.load('saved_model/best_model.pth',
               map_location=lambda storage, loc: storage))
# Freeze every parameter so no gradients are tracked for the weights.
for param in model.parameters():
    param.requires_grad = False

# One counter per entry in `totals`; presumably filled in by code past this
# view -- verify.
total_numbers = [0.0, 0.0, 0.0, 0.0]
adv_numbers = [0.0, 0.0, 0.0, 0.0]

# Concat and test
model.eval()

for i in range(len(totals)):
    # Convert the i-th numpy array to a float32 tensor on `device`.
    total = torch.from_numpy(totals[i]).float().to(device)
    length = lengths[i]
    # kj and jk so labels doubled
    # (the label is repeated once per ordering, giving a 2-element label
    # tensor -- presumably a batch of two; TODO confirm)
    true_label = torch.tensor([True_labels[i], True_labels[i]]).to(device)
Exemplo n.º 3
0
def run_train_valid(opt, vis):
    """Train on the training split and evaluate on the validation split.

    Checkpoints are saved with ``is_CV=True``: every ``opt.save_every``
    epochs, plus a final checkpoint when training ends.  If
    ``opt.model_path`` is set, model weights, optimizer state, epoch and
    learning rate are first restored from that checkpoint.

    Args:
        opt: configuration object.  Read for data paths, model
            hyper-parameters, loader settings, epoch count, checkpoint
            cadence and device.
        vis: visualizer handed unchanged to ``train.train_3d`` /
            ``train.valid_3d``.

    Returns:
        Tuple ``(train_loss_list, train_auc_list, valid_loss_list,
        valid_auc_list)``, each with one value per epoch actually run.
    """
    print(opt.__dict__)
    train_path, valid_path, _ = init_file_path(opt)

    train_dataset = KTData(train_path, opt='None')
    valid_dataset = KTData(valid_path, opt='None')

    print(train_path, valid_path)
    print(len(train_dataset), len(valid_dataset))

    # Both splits share identical loader settings.  drop_last=True is
    # presumably required because the models are built with a fixed
    # batch_size -- TODO confirm before changing it.
    train_loader, valid_loader = (
        DataLoader(dataset,
                   batch_size=opt.batch_size,
                   shuffle=True,
                   num_workers=opt.num_workers,
                   drop_last=True,
                   collate_fn=myutils.collate_fn)
        for dataset in (train_dataset, valid_dataset)
    )

    # The architectures share one constructor signature; any model_name other
    # than "CNN" / "CNN_3D" falls back to RNN_DKT.
    if opt.model_name == "CNN":
        model_cls = CNN
    elif opt.model_name == "CNN_3D":
        model_cls = CNN_3D
    else:
        model_cls = RNN_DKT
    model = model_cls(opt.input_dim, opt.embed_dim, opt.hidden_dim,
                      opt.num_layers, opt.output_dim, opt.batch_size,
                      opt.device)
    # Move the model BEFORE building the optimizer and restoring its state.
    # Optimizer.load_state_dict casts state tensors to the device of their
    # parameters at load time.  Loading while the parameters are still on CPU
    # would therefore leave Adam's moment buffers on CPU after a later move
    # to a GPU.
    model = model.to(opt.device)

    lr = opt.lr
    last_epoch = -1
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=lr,
                                 weight_decay=opt.weight_decay,
                                 betas=(0.9, 0.99))
    if opt.model_path:
        # Deserialize every tensor onto CPU.  load_state_dict then copies the
        # weights / optimizer state onto the model's device.
        checkpoint = torch.load(opt.model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        last_epoch = checkpoint["epoch"]
        lr = checkpoint["lr"]
        optimizer.load_state_dict(checkpoint["optimizer"])

    train_loss_list = []
    train_auc_list = []
    valid_loss_list = []
    valid_auc_list = []

    # Keeps `epoch` bound for the final save below even if the loop body
    # never runs (e.g. opt.max_epoch == 0).
    epoch = last_epoch
    for epoch in range(opt.max_epoch):
        # NOTE(review): on resume, epoch == last_epoch is trained again --
        # confirm whether `<=` is intended.
        if epoch < last_epoch:
            continue

        train_loss_meter, train_auc_meter, _ = train.train_3d(
            opt, vis, model, train_loader, epoch, lr, optimizer)
        val_loss_meter, val_auc_meter, _ = train.valid_3d(
            opt, vis, model, valid_loader, epoch)

        # Read each meter once; only element [0] of value() is reported.
        train_loss = train_loss_meter.value()[0]
        train_auc = train_auc_meter.value()[0]
        val_loss = val_loss_meter.value()[0]
        val_auc = val_auc_meter.value()[0]

        print("epoch: {}, train_auc: {}, val_auc: {}".format(
            epoch, train_auc, val_auc))

        train_loss_list.append(train_loss)
        train_auc_list.append(train_auc)
        valid_loss_list.append(val_loss)
        valid_auc_list.append(val_auc)

        # Save model + optimizer parameters every opt.save_every epochs.
        if epoch % opt.save_every == 0:
            myutils.save_model_weight(opt,
                                      model,
                                      optimizer,
                                      epoch,
                                      lr,
                                      is_CV=True)

        # Learning-rate decay; adjust_lr returns the lr used from now on.
        lr = myutils.adjust_lr(opt, optimizer, epoch)

    # Save the final model parameters when training ends.
    myutils.save_model_weight(opt,
                              model,
                              optimizer,
                              epoch,
                              lr,
                              is_final=True,
                              is_CV=True)

    return train_loss_list, train_auc_list, valid_loss_list, valid_auc_list