import config
from pathlib import Path
from utils_model import train_resnet
from utils import (get_classes, get_log_csv_name)

# Train the ResNet on the "incorrect" training split, resuming from a saved
# checkpoint. Run options are read from config.args; the remaining settings
# (classes, device, class count, normalization mean/std) are read from config.
print("\n\n+++++ Running 3_train.py +++++")
_cli = config.args
train_resnet(
    batch_size=_cli.batch_size,
    checkpoints_folder=Path("checkpoints_incorrect"),
    classes=config.classes,
    color_jitter_brightness=_cli.color_jitter_brightness,
    color_jitter_contrast=_cli.color_jitter_contrast,
    color_jitter_hue=_cli.color_jitter_hue,
    color_jitter_saturation=_cli.color_jitter_saturation,
    device=config.device,
    learning_rate=_cli.learning_rate,
    learning_rate_decay=_cli.learning_rate_decay,
    log_csv=get_log_csv_name(log_folder=Path("logs/incorrect")),
    num_classes=config.num_classes,
    num_layers=_cli.num_layers,
    num_workers=_cli.num_workers,
    path_mean=config.path_mean,
    path_std=config.path_std,
    pretrain=_cli.pretrain,
    # Always continue from this specific saved model instead of the CLI flag.
    resume_checkpoint=True,
    resume_checkpoint_path=Path("checkpoints/resnet18_e10_va0.55588.pt"),
    save_interval=_cli.save_interval,
    num_epochs=_cli.num_epochs,
    train_folder=Path("data/voc_trainval_incorrect/"),
    weight_decay=_cli.weight_decay,
)
print("+++++ Finished running 3_train.py +++++\n\n")
# Example #2
# DeepSlide
# Jason Wei, Behnaz Abdollahi, Saeed Hassanpour

# Training the resnet

from utils_model import train_resnet

if __name__ == '__main__':
    # This example reads every setting from `config` but never imported it
    # itself; import it here so the entry point does not rely on an import
    # made by unrelated code. If the module is already loaded, Python returns
    # the cached module object, so this is harmless.
    import config

    # Train the ResNet using the module-level settings defined in config.
    train_resnet(train_folder=config.train_folder,
                 num_epochs=config.num_epochs,
                 num_layers=config.num_layers,
                 learning_rate=config.learning_rate,
                 batch_size=config.batch_size,
                 weight_decay=config.weight_decay,
                 learning_rate_decay=config.learning_rate_decay,
                 resume_checkpoint=config.resume_checkpoint,
                 resume_checkpoint_path=config.resume_checkpoint_path,
                 save_interval=config.save_interval,
                 checkpoints_folder=config.checkpoints_folder,
                 pretrain=config.pretrain,
                 log_csv=config.log_csv)
# Example #3
import config
from utils_model import train_resnet

# Train the ResNet purely from configuration. Run options (including the
# checkpoint folder, resume flag and training folder) come from config.args;
# classes, device, log file, class count, normalization statistics and the
# resume checkpoint path are attributes of config itself.
print("\n\n+++++ Running 3_train.py +++++")
_opts = config.args
train_resnet(
    batch_size=_opts.batch_size,
    checkpoints_folder=_opts.checkpoints_folder,
    classes=config.classes,
    color_jitter_brightness=_opts.color_jitter_brightness,
    color_jitter_contrast=_opts.color_jitter_contrast,
    color_jitter_hue=_opts.color_jitter_hue,
    color_jitter_saturation=_opts.color_jitter_saturation,
    device=config.device,
    learning_rate=_opts.learning_rate,
    learning_rate_decay=_opts.learning_rate_decay,
    log_csv=config.log_csv,
    num_classes=config.num_classes,
    num_layers=_opts.num_layers,
    num_workers=_opts.num_workers,
    path_mean=config.path_mean,
    path_std=config.path_std,
    pretrain=_opts.pretrain,
    resume_checkpoint=_opts.resume_checkpoint,
    resume_checkpoint_path=config.resume_checkpoint_path,
    save_interval=_opts.save_interval,
    num_epochs=_opts.num_epochs,
    train_folder=_opts.train_folder,
    weight_decay=_opts.weight_decay,
)
print("+++++ Finished running 3_train.py +++++\n\n")
# Example #4
# Training the ResNet.
# NOTE(review): this call reads exp_num, classes, device, log_csv,
# train_order_csv, num_classes, path_mean, path_std and train_folder, which
# must already be bound when it runs. In this file most of them are only
# assigned further down, and path_mean/path_std are not assigned anywhere in
# the visible source -- confirm where they are meant to come from.
print("\n\n+++++ Running 3_train.py +++++")
train_resnet(
    batch_size=256,
    # Join the per-experiment directory with pathlib rather than string
    # concatenation; the resulting path is the same.
    checkpoints_folder=(
        Path("/home/brenta/scratch/jason/checkpoints/image_net/grad_cl")
        / f"exp_{exp_num}"
    ),
    classes=classes,
    color_jitter_brightness=0,
    color_jitter_contrast=0,
    color_jitter_hue=0,
    color_jitter_saturation=0,
    device=device,
    learning_rate=0.0001,
    learning_rate_decay=0.5,
    log_csv=log_csv,
    train_order_csv=train_order_csv,
    num_classes=num_classes,
    num_layers=18,
    num_workers=8,
    path_mean=path_mean,
    path_std=path_std,
    pretrain=False,
    # Train from scratch: no pretrained weights and no checkpoint to resume.
    resume_checkpoint=False,
    resume_checkpoint_path=None,
    save_interval=0,
    num_epochs=200,
    train_folder=train_folder,
    weight_decay=1e-4)
print("+++++ Finished running 3_train.py +++++\n\n")
from utils import (get_classes, get_log_csv_name, get_log_csv_train_order)
from utils_model import train_resnet

# torch is needed for device selection below but is not imported anywhere
# else in this file; without this line torch.device raises NameError.
import torch

# Identifier of this experiment; used to name its checkpoint and log folders.
exp_num = 62

train_folder = Path(
    "/home/brenta/scratch/jason/data/imagenet/grad_mb10000_0.5/train")
val_folder = Path(
    "/home/brenta/scratch/jason/data/imagenet/grad_mb10000_0.5/val")
# Per-experiment output folders, joined with pathlib instead of string
# concatenation (same resulting paths).
checkpoints_folder = (
    Path("/home/brenta/scratch/jason/checkpoints/imagenet/grad_pred")
    / f"exp_{exp_num}")
log_folder = (
    Path("/home/brenta/scratch/jason/logs/imagenet/grad_pred")
    / f"exp_{exp_num}")
log_csv = get_log_csv_name(log_folder=log_folder)
train_order_csv = get_log_csv_train_order(log_folder=log_folder)
# One class per entry returned by get_classes for the training folder.
classes = get_classes(train_folder)
num_classes = len(classes)
# Prefer the first GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): train_order_csv is computed above but None is passed here;
# confirm that skipping the train-order log for this run is intended.
train_resnet(train_folder=train_folder,
             val_folder=val_folder,
             checkpoints_folder=checkpoints_folder,
             train_order_csv=None,
             log_csv=log_csv,
             classes=classes,
             num_classes=num_classes,
             device=device,
             save_mb_interval=10000,
             val_mb_interval=10000,
             num_epochs=10)
# Resume training for one epoch from a saved exp_44 checkpoint.
# NOTE(review): path_mean and path_std are not assigned anywhere in the
# visible source -- confirm where they are defined. Also, train_folder here
# is whatever was last bound (an ImageNet folder earlier in this file) while
# the checkpoint and output folders are under voc/vanilla; confirm the
# dataset/checkpoint pairing is intended.
print("\n\n+++++ Running 3_train.py +++++")
train_resnet(
    batch_size=256,
    # Join the per-experiment directory with pathlib rather than string
    # concatenation; the resulting path is the same.
    checkpoints_folder=(
        Path("/home/brenta/scratch/jason/checkpoints/voc/vanilla")
        / f"exp_{exp_num}"
    ),
    classes=classes,
    color_jitter_brightness=0,
    color_jitter_contrast=0,
    color_jitter_hue=0,
    color_jitter_saturation=0,
    device=device,
    learning_rate=0.0001,
    learning_rate_decay=0.5,
    log_csv=log_csv,
    train_order_csv=train_order_csv,
    num_classes=num_classes,
    num_layers=18,
    num_workers=8,
    path_mean=path_mean,
    path_std=path_std,
    pretrain=False,
    resume_checkpoint=True,
    resume_checkpoint_path=Path(
        "/home/brenta/scratch/jason/checkpoints/voc/vanilla/exp_44/resnet18_e0_mb40_va0.45239.pt"
    ),
    save_interval=0,
    num_epochs=1,
    train_folder=train_folder,
    weight_decay=1e-4)
print("+++++ Finished running 3_train.py +++++\n\n")