Ejemplo n.º 1
0
def main():
    """Interface for training and evaluating using the command line.

    Parses the command-line arguments with the module-level ``parser`` and
    stores them in the global ``args``. Builds a ``SiameseNetwork``, optionally
    restores its weights from ``args.checkpoint``, and then, depending on the
    flags:

    * ``args.mode == 'train'``: trains the model and, if ``args.save_dir`` is
      set, saves a checkpoint (epoch, model and optimiser state).
    * ``args.mode == 'eval'`` or ``args.post_training_eval``: evaluates the
      model on ``args.which_set`` (forced to ``'val'`` after training).
    * ``args.make_plot``: plots the stored evaluation predictions.

    Raises:
        AttributeError: if ``args.make_plot`` is set but the flags do not
            request an evaluation run, whose results the plot needs.
    """
    global args
    args = parser.parse_args()

    # Evaluation happens either explicitly ('eval' mode) or after training.
    do_eval = args.mode == 'eval' or args.post_training_eval
    # Plotting needs evaluation results; validate the flags before doing any
    # (potentially long) training instead of failing at the very end.
    if args.make_plot and not do_eval:
        raise AttributeError('The flags provided did not specify to evaluate the dataset, which is required for '
                             'plotting')

    model = SiameseNetwork(1, args.embedding_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # If a checkpoint is provided, load its values (tensors mapped onto `device`).
    if args.checkpoint:
        state = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(state['state_dict'])

    # Run the model on a GPU if available
    model.to(device)

    # Train the network
    if args.mode == 'train':
        dataset = GEDDataset(args.data, which_set='train', adj_dtype=np.float32, transform=None)
        model, optimiser, epoch = train(model, dataset, batch_size=args.batch_size, embed_size=args.embedding_size,
                                        num_epochs=args.epochs, learning_rate=args.learning_rate,
                                        save_to=args.save_dir, resume_state=args.checkpoint, device=device)

        # Save the model checkpoint. `optimiser` and `epoch` only exist after a
        # training run, so saving must stay inside this branch.
        if args.save_dir:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimiser': optimiser.state_dict(),
            }
            save_checkpoint(state, args.save_dir)

        # Post-training evaluation always runs on the validation split.
        if args.post_training_eval:
            args.which_set = 'val'

    if do_eval:
        dataset = GEDDataset(args.data, which_set=args.which_set, adj_dtype=np.float32, transform=None)
        # Predictions are only stored when they will be plotted.
        # NOTE(review): `eval` here presumably is the project's evaluation
        # function (it shadows the builtin) — verify the import.
        results = eval(model, dataset, batch_size=args.batch_size, store_results=args.make_plot, device=device)

    # Finally, if plotting the results (do_eval is guaranteed by the check above).
    if args.make_plot:
        print('Making the plot')
        plot_prediction(results[0], results[1])
Ejemplo n.º 2
0
    n_val = int(len(dataset) * args.val_size)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=args.batch_size)
    val_loader = DataLoader(val, batch_size=args.batch_size // 4)

    # load backbones
    print("[*] Initializing weights...")
    imagenet_net = ResNet34()
    sketches_net = ResNet34()
    # sketches_net.load_state_dict(torch.load(args.sketches_backbone_weights))
    print("[+] Weights loaded")

    print("[*] Initializing model, loss and optimizer")
    contrastive_net = SiameseNetwork(sketches_net, imagenet_net)
    contrastive_net.to(args.device)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(contrastive_net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    else:
        optimizer = torch.optim.Adam(contrastive_net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=args.t_0)
    contrastive_loss = contrastive_loss()
    cross_entropy_loss = torch.nn.CrossEntropyLoss()
    print("[+] Model, loss and optimizer were initialized successfully")

    if not args.debug:
        wandb.init(project='homework1-cc7221', entity='p137')
        config = wandb.config
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

autoencoder = AutoEncoder(config)
siamese_network = SiameseNetwork(config)

# Checkpoint file names. They start with '/' because they are appended to
# config.saved_models_folder by plain string concatenation below.
autoencoder_file = '/autoencoder_epoch175_loss1.1991.pth'
siamese_file = '/siamese_network_epoch175_loss1.1991.pth'

if config.load_model:
    # map_location lets checkpoints saved on a GPU be restored on a CPU-only
    # machine (without it torch.load tries to recreate the CUDA tensors).
    autoencoder.load_state_dict(
        torch.load(config.saved_models_folder + autoencoder_file, map_location=device))
    siamese_network.load_state_dict(
        torch.load(config.saved_models_folder + siamese_file, map_location=device))

autoencoder.to(device)
autoencoder.train()

siamese_network.to(device)
siamese_network.train()

# Both networks are optimised jointly by a single Adam optimiser.
params = list(autoencoder.parameters()) + list(siamese_network.parameters())

optimizer = torch.optim.Adam(params, lr=config.lr, betas=(0.9, 0.999))

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    # Disabled augmentations:
    # transforms.RandomCrop(size=128),
    # transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
])
# ImageFolder infers the class labels from the sub-directory names.
train_data = torchvision.datasets.ImageFolder(config.data_folder, transform=transform)
# drop_last discards a final incomplete batch so every batch has batch_size items.
train_data_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True)
Ejemplo n.º 4
0
) if Config.network == 'siamese' else TriMadoriDataset()
# Validation split for whichever architecture Config selects.
if Config.network == 'siamese':
    val_dataset = MadoriDataset(train=False)
else:
    val_dataset = TriMadoriDataset(train=False)

# Data loaders: only the training data is shuffled.
train_dataloader = DataLoader(
    train_dataset, batch_size=Config.batch_size, shuffle=True)
val_dataloader = DataLoader(
    val_dataset, batch_size=Config.batch_size, shuffle=False)

# Model for the chosen architecture, moved onto the target device.
if Config.network == 'siamese':
    net = SiameseNetwork()
else:
    net = TripletNetwork()
net = net.to(device)

# Loss paired with the architecture: contrastive for pairs, triplet otherwise.
if Config.network == 'siamese':
    criterion = ContrastiveLoss()
else:
    criterion = TripletLoss()

optimizer = optim.Adam(net.parameters(), lr=0.0005)


def train_siamese():
    train_loss_history, val_loss_history = [], []
    lowest_epoch_train_loss = lowest_epoch_val_loss = float('inf')

    for epoch in tqdm(range(Config.train_number_epochs)):
        # training
        net.train()
        epoch_train_loss = 0
        for batch_no, data in enumerate(train_dataloader):
            img0, img1, label = data
            img0, img1, label = img0.to(device), img1.to(device), label.to(
Ejemplo n.º 5
0
    train_loader = DataLoader(train, batch_size=args.batch_size)
    val_loader = DataLoader(val, batch_size=args.batch_size)
    print("[+] Dataset initialized successfully")

    # load backbones
    print("[*] Initializing weights...")
    imagenet_net = ResNet34()
    sketches_net = ResNet34()
    # sketches_net.load_state_dict(torch.load(args.sketches_backbone_weights))
    print("[+] Weights loaded")

    print("[*] Adapting output layers...")

    print("[*] Initializing model, loss and optimizer")
    siamese_net = SiameseNetwork(sketches_net, imagenet_net)
    siamese_net.to(args.device)
    optimizer = torch.optim.Adam(siamese_net.parameters(), lr=args.lr)
    triplet_loss = triplet_loss()
    cross_entropy_loss = torch.nn.CrossEntropyLoss()
    print("[+] Model, loss and optimizer were initialized successfully")

    if not args.debug:
        wandb.init(project='homework1-cc7221', entity='p137')
        config = wandb.config
        config.model = siamese_net.__class__.__name__ + "_triplet"
        config.device = device
        config.batch_size = args.batch_size
        config.epochs = args.epochs
        config.learning_rate = args.lr

    print("[*] Training")