Example 1
import torch
import torch.backends.cudnn as cudnn

# SupConResNet and SupConLoss come from the surrounding project (in
# SupContrast-style code: networks.resnet_big and losses); apex is NVIDIA's
# library and is only needed when opt.syncBN is set.


def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion
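
A minimal usage sketch (the option fields mirror those read by set_model above;
the values are illustrative, and the [batch, n_views, feat_dim] feature layout
is the SupContrast convention for SupConLoss):

from argparse import Namespace

import torch
import torch.nn.functional as F

opt = Namespace(model="resnet50", temp=0.07, syncBN=False)
model, criterion = set_model(opt)

# SupConLoss expects L2-normalized features shaped [batch, n_views, feat_dim]
# plus one class label per sample.
features = F.normalize(torch.randn(8, 2, 128), dim=2)
labels = torch.randint(0, 10, (8,))
loss = criterion(features, labels)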
Example 2
import os
import time

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# Project-level helpers (Graph, dataset_factory, the pose transforms,
# get_model_resgcn, SupConLoss, train, evaluate, the checkpoint utilities,
# and ray's tune) are assumed to be importable from the surrounding repository.


def main(opt):
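    """Train a ResGCN pose-sequence encoder with a supervised contrastive loss."""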
    opt = setup_environment(opt)
    graph = Graph("coco")

    # Dataset
    transform = transforms.Compose([
        MirrorPoses(opt.mirror_probability),
        FlipSequence(opt.flip_probability),
        RandomSelectSequence(opt.sequence_length),
        ShuffleSequence(opt.shuffle),
        PointNoise(std=opt.point_noise_std),
        JointNoise(std=opt.joint_noise_std),
        MultiInput(graph.connect_joint, opt.use_multi_branch),
        ToTensor()
    ])

    dataset_class = dataset_factory(opt.dataset)
    dataset = dataset_class(
        opt.train_data_path,
        train=True,
        sequence_length=opt.sequence_length,
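        # TwoNoiseTransform presumably applies the pipeline twice per sample,
        # yielding two augmented views as the supervised contrastive loss expects.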
        transform=TwoNoiseTransform(transform),
    )

    dataset_valid = dataset_class(
        opt.valid_data_path,
        sequence_length=opt.sequence_length,
        transform=transforms.Compose([
            SelectSequenceCenter(opt.sequence_length),
            MultiInput(graph.connect_joint, opt.use_multi_branch),
            ToTensor()
        ]),
    )

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        pin_memory=True,
        shuffle=True,
    )

    val_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=opt.batch_size_validation,
        num_workers=opt.num_workers,
        pin_memory=True,
    )

    # Model & criterion
    model, model_args = get_model_resgcn(graph, opt)
    criterion = SupConLoss(temperature=opt.temp)

    print("# parameters: ", count_parameters(model))

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, opt.gpus)

    if opt.cuda:
        model.cuda()
        criterion.cuda()

    # Trainer
    optimizer, scheduler, scaler = get_trainer(model, opt, len(train_loader))
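    # get_trainer receives len(train_loader), which suggests a per-iteration LR
    # schedule; `scaler` is presumably a torch.cuda.amp.GradScaler for AMP.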

    # Load checkpoint or weights
    load_checkpoint(model, optimizer, scheduler, scaler, opt)

    # Tensorboard
    writer = SummaryWriter(log_dir=opt.tb_path)

    # Dummy input used only to trace the model graph for TensorBoard.
    sample_input = torch.zeros(opt.batch_size, model_args["num_input"],
                               model_args["num_channel"], opt.sequence_length,
                               graph.num_node)
    if opt.cuda:
        sample_input = sample_input.cuda()
    writer.add_graph(model, input_to_model=sample_input)

    best_acc = 0
    loss = 0
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, scheduler,
                     scaler, epoch, opt)

        time2 = time.time()
        print(f"epoch {epoch}, total time {time2 - time1:.2f}")

        # tensorboard logger
        writer.add_scalar("loss/train", loss, epoch)
        writer.add_scalar("learning_rate", optimizer.param_groups[0]["lr"],
                          epoch)

        # evaluation
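        # use_flip=True presumably also scores horizontally flipped sequences
        # as test-time augmentation before averaging.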
        result, accuracy_avg, sub_accuracies, dataframe = evaluate(
            val_loader, model, opt.evaluation_fn, use_flip=True)
        writer.add_text("accuracy/validation", dataframe.to_markdown(), epoch)
        writer.add_scalar("accuracy/validation", accuracy_avg, epoch)
        for key, sub_accuracy in sub_accuracies.items():
            writer.add_scalar(f"accuracy/validation/{key}", sub_accuracy,
                              epoch)

        print(f"epoch {epoch}, avg accuracy {accuracy_avg:.4f}")
        is_best = accuracy_avg > best_acc
        if is_best:
            best_acc = accuracy_avg

        if opt.tune:
            tune.report(accuracy=accuracy_avg)

        if epoch % opt.save_interval == 0 or (
                is_best and epoch > opt.save_best_start * opt.epochs):
            save_file = os.path.join(
                opt.save_folder,
                f"ckpt_epoch_{'best' if is_best else epoch}.pth")
            save_model(model, optimizer, scheduler, scaler, opt, epoch,
                       save_file)

    # save the last model
    save_file = os.path.join(opt.save_folder, "last.pth")
    save_model(model, optimizer, scheduler, scaler, opt, opt.epochs, save_file)

    log_hyperparameter(writer, opt, best_acc, loss)

    print(f"best accuracy: {best_acc*100:.2f}")