Ejemplo n.º 1
0
def load_model(model_path, pretrained=True):
    """Build a ``GatedGCNReactionNetwork`` from a saved training directory.

    The architecture hyperparameters are read from ``train_args.yaml`` in
    ``model_path`` (the file name the training script passes to
    ``yaml_dump``). If requested, the trained weights are then restored from
    ``checkpoint.pkl`` in the same directory.

    Args:
        model_path: Directory containing ``train_args.yaml`` and, when
            ``pretrained`` is true, ``checkpoint.pkl``. May be a
            ``pathlib.Path`` or a plain string.
        pretrained: If true, load the weights from ``checkpoint.pkl`` with
            tensors mapped onto the CPU. Otherwise return the freshly
            constructed model.

    Returns:
        The constructed ``GatedGCNReactionNetwork`` (``outdim=1``,
        ``conv="GatedGCNConv"``).
    """
    from pathlib import Path

    # Accept str as well as Path; the original code required an object with
    # ``joinpath`` and crashed on plain string paths.
    model_path = Path(model_path)

    # NOTE cannot use bondnet.utils.yaml_load, which uses the safe_loader.
    # see: https://github.com/yaml/pyyaml/issues/266
    # SECURITY: yaml.Loader can construct arbitrary Python objects (that is
    # why it can rebuild the saved args object). Only load train_args.yaml
    # files that come from a trusted source.
    with open(model_path.joinpath("train_args.yaml"), "r", encoding="utf-8") as f:
        model_args = yaml.load(f, Loader=yaml.Loader)

    model = GatedGCNReactionNetwork(
        in_feats=model_args.feature_size,
        embedding_size=model_args.embedding_size,
        gated_num_layers=model_args.gated_num_layers,
        gated_hidden_size=model_args.gated_hidden_size,
        gated_num_fc_layers=model_args.gated_num_fc_layers,
        gated_graph_norm=model_args.gated_graph_norm,
        gated_batch_norm=model_args.gated_batch_norm,
        gated_activation=model_args.gated_activation,
        gated_residual=model_args.gated_residual,
        gated_dropout=model_args.gated_dropout,
        num_lstm_iters=model_args.num_lstm_iters,
        num_lstm_layers=model_args.num_lstm_layers,
        set2set_ntypes_direct=model_args.set2set_ntypes_direct,
        fc_num_layers=model_args.fc_num_layers,
        fc_hidden_size=model_args.fc_hidden_size,
        fc_batch_norm=model_args.fc_batch_norm,
        fc_activation=model_args.fc_activation,
        fc_dropout=model_args.fc_dropout,
        outdim=1,
        conv="GatedGCNConv",
    )

    if pretrained:
        # Map stored tensors onto the CPU so a GPU-trained checkpoint can be
        # loaded on a machine without CUDA.
        load_checkpoints(
            {"model": model},
            map_location=torch.device("cpu"),
            filename=model_path.joinpath("checkpoint.pkl"),
        )

    return model
Ejemplo n.º 2
0
def _is_main_process(args):
    """Return True if this process should print progress and write files.

    Every process qualifies in a non-distributed run. In a distributed run
    only the process with ``args.gpu == 0`` does.
    """
    return not args.distributed or args.gpu == 0


def _load_checkpoint(state_dict_objs, gpu, filename):
    """Restore ``state_dict_objs`` from ``filename`` via ``load_checkpoints``.

    Args:
        state_dict_objs: Mapping of name -> object whose state is restored.
        gpu: GPU index of this process, or None. When not None, the stored
            tensors are mapped onto ``cuda:<gpu>``. When None,
            ``load_checkpoints`` is called without ``map_location``.
        filename: Checkpoint file to read.

    Returns:
        Whatever ``load_checkpoints`` returns.
    """
    if gpu is None:
        return load_checkpoints(state_dict_objs, filename=filename)
    # Map the model to be loaded to the specified single GPU.
    return load_checkpoints(
        state_dict_objs, map_location="cuda:{}".format(gpu), filename=filename
    )


def _resolve_dataset_state_dict_filename(args):
    """Return the dataset state-dict file to pass to the dataset, or None.

    A file is only used when ``args.restore`` is set. A configured file that
    does not exist is reported with a warning and replaced by None.
    """
    if not args.restore:
        return None

    filename = args.dataset_state_dict_filename
    if filename is None:
        warnings.warn(
            "Restore with `args.dataset_state_dict_filename` set to None.")
    elif not Path(filename).exists():
        warnings.warn(f"`{filename}` not found; set "
                      f"`args.dataset_state_dict_filename` to None")
        filename = None
    return filename


def main_worker(gpu, world_size, args):
    """Train, validate and test a ``GatedGCNReactionNetwork`` in one process.

    When ``args.distributed`` is set, this runs once per GPU, and ``gpu`` is
    also used as the process rank. Only the main process (see
    ``_is_main_process``) prints progress and writes the dataset state dict,
    ``train_args.yaml``, checkpoints, and the hyperparameter-tuning result
    (``args.output_file``).

    Args:
        gpu: GPU index for this process, or None to leave the model where it
            was constructed. Stored on ``args.gpu``.
        world_size: Total number of processes. Passed to
            ``dist.init_process_group`` in distributed mode.
        args: Parsed options. Mutated in place: ``gpu``, ``feature_size``,
            ``set2set_ntypes_direct``, and ``start_epoch`` when a checkpoint
            is restored.

    Side effects:
        Reads and updates the module-level ``best`` (the best validation
        score so far; lower is better). Calls ``sys.exit(1)`` if the
        training loss becomes NaN.
    """
    global best
    args.gpu = gpu
    is_main = _is_main_process(args)

    if is_main:
        print("\n\nStart training at:", datetime.now())

    if args.distributed:
        dist.init_process_group(
            args.dist_backend,
            init_method=args.dist_url,
            world_size=world_size,
            rank=args.gpu,
        )

    # Explicitly setting seed to ensure the same dataset split and models created in
    # two processes (when distributed) start from the same random weights and biases
    seed_torch()

    dataset_state_dict_filename = _resolve_dataset_state_dict_filename(args)

    # convert reactions in csv file to atom mapped label file if necessary
    mols, attrs, labels = read_input_files(args.molecule_file,
                                           args.molecule_attributes_file,
                                           args.reaction_file)
    dataset = ReactionNetworkDataset(
        grapher=get_grapher(),
        molecules=mols,
        labels=labels,
        extra_features=attrs,
        feature_transformer=True,
        label_transformer=True,
        state_dict_filename=dataset_state_dict_filename,
    )

    trainset, valset, testset = train_validation_test_split(dataset,
                                                            validation=0.1,
                                                            test=0.1)

    if is_main:
        # Persist the dataset state so a later `restore` run can read it back.
        # Skip when no file is configured: torch.save cannot write to None,
        # and the restore branch above explicitly tolerates a None filename.
        if args.dataset_state_dict_filename is not None:
            torch.save(dataset.state_dict(), args.dataset_state_dict_filename)
        print("Trainset size: {}, valset size: {}: testset size: {}.".format(
            len(trainset), len(valset), len(testset)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    else:
        train_sampler = None

    train_loader = DataLoaderReactionNetwork(
        trainset,
        batch_size=args.batch_size,
        # A sampler and shuffle=True are mutually exclusive.
        shuffle=(train_sampler is None),
        sampler=train_sampler,
    )
    # A larger val and test set batch_size is faster but needs more memory;
    # adjust the batch size to fit memory.
    val_loader = DataLoaderReactionNetwork(valset,
                                           batch_size=max(len(valset) // 10, 1),
                                           shuffle=False)
    test_loader = DataLoaderReactionNetwork(testset,
                                            batch_size=max(len(testset) // 10, 1),
                                            shuffle=False)

    ### model

    feature_names = ["atom", "bond", "global"]
    set2set_ntypes_direct = ["global"]

    # Record the data-dependent settings on args so they are saved to
    # train_args.yaml below, alongside the user options.
    args.feature_size = dataset.feature_size
    args.set2set_ntypes_direct = set2set_ntypes_direct

    # save args
    if is_main:
        yaml_dump(args, "train_args.yaml")

    model = GatedGCNReactionNetwork(
        in_feats=args.feature_size,
        embedding_size=args.embedding_size,
        gated_num_layers=args.gated_num_layers,
        gated_hidden_size=args.gated_hidden_size,
        gated_num_fc_layers=args.gated_num_fc_layers,
        gated_graph_norm=args.gated_graph_norm,
        gated_batch_norm=args.gated_batch_norm,
        gated_activation=args.gated_activation,
        gated_residual=args.gated_residual,
        gated_dropout=args.gated_dropout,
        num_lstm_iters=args.num_lstm_iters,
        num_lstm_layers=args.num_lstm_layers,
        set2set_ntypes_direct=args.set2set_ntypes_direct,
        fc_num_layers=args.fc_num_layers,
        fc_hidden_size=args.fc_hidden_size,
        fc_batch_norm=args.fc_batch_norm,
        fc_activation=args.fc_activation,
        fc_dropout=args.fc_dropout,
        outdim=1,
        conv="GatedGCNConv",
    )

    if is_main:
        print(model)

    if args.gpu is not None:
        model.to(args.gpu)
    if args.distributed:
        ddp_model = DDP(model, device_ids=[args.gpu])
        # Re-expose the wrapped module's feature_before_fc on the wrapper.
        # Presumably some caller (e.g. train/evaluate) accesses it on the
        # model directly; confirm.
        ddp_model.feature_before_fc = model.feature_before_fc
        model = ddp_model

    ### optimizer, loss, and metric
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    loss_func = MSELoss(reduction="mean")
    metric = WeightedL1Loss(reduction="sum")

    ### learning rate scheduler and stopper
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode="min",
                                  factor=0.4,
                                  patience=50,
                                  verbose=True)
    stopper = EarlyStopping(patience=150)

    # load checkpoint
    state_dict_objs = {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler
    }
    if args.restore:
        try:
            checkpoint = _load_checkpoint(state_dict_objs, args.gpu,
                                          "checkpoint.pkl")
        except FileNotFoundError as e:
            warnings.warn(str(e) + " Continue without loading checkpoints.")
        else:
            # The checkpoint is written after epoch `epoch` has finished
            # (see the training loop), so resume with the next epoch instead
            # of repeating the one already done.
            args.start_epoch = checkpoint["epoch"] + 1
            best = checkpoint["best"]
            print(
                f"Successfully load checkpoints, best {best}, epoch {args.start_epoch}"
            )

    # start training
    if is_main:
        print(
            "\n\n# Epoch     Loss         TrainAcc        ValAcc     Time (s)")
        sys.stdout.flush()

    for epoch in range(args.start_epoch, args.epochs):
        ti = time.time()

        # In distributed mode, calling the set_epoch method is needed to make shuffling
        # work; each process will use the same random seed otherwise.
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train
        loss, train_acc = train(optimizer, model, feature_names, train_loader,
                                loss_func, metric, args.gpu)

        # bad, we get nan
        if np.isnan(loss):
            print("\n\nBad, we get nan for loss. Exiting")
            sys.stdout.flush()
            sys.exit(1)

        # evaluate
        val_acc = evaluate(model, feature_names, val_loader, metric, args.gpu)

        if stopper.step(val_acc):
            break

        scheduler.step(val_acc)

        is_best = val_acc < best
        if is_best:
            best = val_acc

        # save checkpoint
        if is_main:
            misc_objs = {"best": best, "epoch": epoch}

            save_checkpoints(
                state_dict_objs,
                misc_objs,
                is_best,
                msg=f"epoch: {epoch}, score {val_acc}",
            )

            tt = time.time() - ti

            print("{:5d}   {:12.6e}   {:12.6e}   {:.2f}".format(
                epoch, loss, train_acc, val_acc, tt)
                  if False else
                  "{:5d}   {:12.6e}   {:12.6e}   {:12.6e}   {:.2f}".format(
                      epoch, loss, train_acc, val_acc, tt))
            if epoch % 10 == 0:
                sys.stdout.flush()

    if is_main:
        # Save results for hyperparameter tuning. Previously this was written
        # only when early stopping fired (so a run that used every epoch wrote
        # nothing), and in distributed mode every rank wrote the same file.
        # `best` is not updated on the early-stop iteration, so the value in
        # that case is unchanged.
        pickle_dump(best, args.output_file)

        # Load the best weights to compute the test accuracy. Only this process
        # evaluates the test set, and it is the one that writes the
        # checkpoints, so other ranks do not need to (and should not race to)
        # read best_checkpoint.pkl.
        _load_checkpoint(state_dict_objs, args.gpu, "best_checkpoint.pkl")

        test_acc = evaluate(model, feature_names, test_loader, metric,
                            args.gpu)

        print("\n#TestAcc: {:12.6e} \n".format(test_acc))
        print("\nFinish training at:", datetime.now())