コード例 #1
0
ファイル: train.py プロジェクト: anja121/hair-segmentation
def _list_image_mask_paths(dataset_root, split):
    """Return ``(image_paths, mask_paths)`` for one dataset split.

    Images are collected from ``<dataset_root>/<split>/images/*`` and masks
    from ``<dataset_root>/<split>/masks/*``. Both lists are sorted so that the
    i-th image is paired with the i-th mask: ``glob.glob`` on its own returns
    files in arbitrary, filesystem-dependent order, which would silently
    mismatch images and masks.

    NOTE(review): pairing by sorted position assumes image and mask file
    names sort in the same order (e.g. matching basenames) -- confirm.

    Args:
        dataset_root: Root directory of the dataset (trailing slash optional).
        split: Sub-directory name of the split, e.g. ``"train"``.

    Returns:
        Tuple of two sorted lists of path strings (possibly both empty).

    Raises:
        ValueError: If the split has a different number of images and masks.
    """
    import os  # local import: the module's import block is not part of this edit

    # Escape the user-supplied root so glob metacharacters in it ("[", "*", "?")
    # are matched literally; only the final "*" acts as a wildcard.
    split_dir = os.path.join(glob.escape(dataset_root), split)
    img_paths = sorted(glob.glob(os.path.join(split_dir, "images", "*")))
    mask_paths = sorted(glob.glob(os.path.join(split_dir, "masks", "*")))

    if len(img_paths) != len(mask_paths):
        raise ValueError(
            f"{split!r} split under {dataset_root!r} has {len(img_paths)} "
            f"images but {len(mask_paths)} masks"
        )
    return img_paths, mask_paths


def main():
    """Train the segmentation model configured in ``configs/train_config.json``.

    Builds a training dataset (and a validation dataset, when validation files
    exist) from ``<path_to_dataset>/{train,valid}/{images,masks}/``. It then fits
    the model returned by ``get_model`` using the callbacks from
    ``get_callbacks``, and saves the result as
    ``<model_prefix>_<YYYYmmdd-HHMMSS>.h5``.

    If the validation split contains no images, training runs without
    validation.

    Raises:
        FileNotFoundError: If the training split contains no images.
        ValueError: If a split has a different number of images and masks.
    """
    # read train config
    conf = read_config_file("configs/train_config.json")

    # Timestamp passed to the callbacks and reused in the saved model file name.
    model_suffix = datetime.now().strftime("%Y%m%d-%H%M%S")
    model_name_prefix = str(conf["model_prefix"])
    img_size = (conf["img_size"], conf["img_size"])
    batch_size = conf["batch_size"]

    # get dataset (sorted so images and masks are paired consistently)
    train_imgs, train_masks = _list_image_mask_paths(conf["path_to_dataset"], "train")
    valid_imgs, valid_masks = _list_image_mask_paths(conf["path_to_dataset"], "valid")

    if not train_imgs:
        raise FileNotFoundError(
            f"no training images found in 'train/images/' under "
            f"{conf['path_to_dataset']!r}"
        )

    train_data = SemSegDataSet(img_paths=train_imgs,
                               mask_paths=train_masks,
                               img_size=img_size,
                               channels=(3, 1),
                               crop_percent_range=(0.75, 0.95),
                               seed=42
                               )

    # An empty validation split means "train without validation".
    # NOTE(review): validation uses the same crop range as training; if
    # SemSegDataSet crops randomly, validation metrics may be noisy -- confirm.
    valid_data = None
    if valid_imgs:
        valid_data = SemSegDataSet(img_paths=valid_imgs,
                                   mask_paths=valid_masks,
                                   img_size=img_size,
                                   channels=(3, 1),
                                   crop_percent_range=(0.75, 0.95),
                                   seed=42
                                   )

    train_size = train_data.size
    valid_size = valid_data.size if valid_data is not None else 0

    # NOTE(review): a full-size shuffle buffer is passed while shuffle=False,
    # so the buffer is presumably ignored -- confirm whether shuffle=True
    # was intended for training.
    train_data = train_data.batch(batch_size=batch_size,
                                  shuffle=False,
                                  shuffle_buffer=train_size)

    if valid_data is not None:
        valid_data = valid_data.batch(batch_size=batch_size)

    # get model
    model = get_model(conf["img_size"])
    model.summary()

    # get train callback functions
    callbacks = get_callbacks(conf, model_suffix)

    # train model
    model.fit(train_data,
              epochs=conf["num_epochs"],
              steps_per_epoch=train_size // batch_size,
              validation_data=valid_data,
              validation_steps=(valid_size // batch_size
                                if valid_data is not None else None),
              callbacks=callbacks)

    # save model
    model.save(f"{model_name_prefix}_{model_suffix}.h5")
コード例 #2
0
        args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
        args.dataset, args.model, args.iid, args.num_users, args.frac,
        args.local_ep, args.shard_per_user, args.results_save)
    if not os.path.exists(os.path.join(base_dir, 'local')):
        os.makedirs(os.path.join(base_dir, 'local'), exist_ok=True)

    dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(
        args)
    dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
    with open(dict_save_path, 'rb') as handle:
        dict_users_train, dict_users_test = pickle.load(handle)

    # build model
    net_glob = get_model(args)
    net_glob.train()

    net_local_list = []
    for user_ix in range(args.num_users):
        net_local_list.append(copy.deepcopy(net_glob))

    # training
    results_save_path = os.path.join(base_dir, 'local/results.csv')

    loss_train = []
    net_best = None
    best_loss = None
    best_acc = None
    best_epoch = None