예제 #1
0
def main():
    """Interface for training and evaluating using the command line.

    Parses CLI args, builds the SiameseNetwork, optionally restores a
    checkpoint, then trains and/or evaluates according to ``args.mode``,
    and finally plots predictions if requested.
    """
    global args
    args = parser.parse_args()

    model = SiameseNetwork(1, args.embedding_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # If a checkpoint is provided, load its values
    if args.checkpoint:
        state = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(state['state_dict'])
    else:
        state = None

    # Run the model on a GPU if available
    model.to(device)

    # Train the network
    if args.mode == 'train':
        dataset = GEDDataset(args.data, which_set='train', adj_dtype=np.float32, transform=None)
        model, optimiser, epoch = train(model, dataset, batch_size=args.batch_size, embed_size=args.embedding_size, num_epochs=args.epochs,
              learning_rate=args.learning_rate, save_to=args.save_dir, resume_state=args.checkpoint, device=device)

        # BUGFIX: checkpoint saving must live inside the train branch --
        # `epoch` and `optimiser` only exist after train() returns, so
        # reaching this code in eval mode raised a NameError.
        if args.save_dir:
            # Save the model checkpoint
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimiser': optimiser.state_dict(),
            }
            save_checkpoint(state, args.save_dir)

    # Whether to store the predictions from eval for plotting
    store_res = args.make_plot

    # A post-training evaluation always runs on the validation split
    if args.mode == 'train' and args.post_training_eval:
        args.which_set = 'val'
    if args.mode == 'eval' or args.post_training_eval:
        dataset = GEDDataset(args.data, which_set=args.which_set, adj_dtype=np.float32, transform=None)
        results = eval(model, dataset, batch_size=args.batch_size, store_results=store_res, device=device)

    # Finally, if plotting the results:
    if args.make_plot:
        # Assert that the data has been evaluated
        if not (args.mode == 'eval' or args.post_training_eval):
            raise AttributeError('The flags provided did not specify to evaluate the dataset, which is required for'
                                 'plotting')
        # Make a plot of the results
        print('Making the plot')
        plot_prediction(results[0], results[1])
def train(labellist, batch_size, train_number_epochs, learning_rate, round,
          device):
    """Train a SiameseNetwork with contrastive loss.

    Builds a SiameseDataset from 'images.txt' and *labellist*, trains for
    *train_number_epochs* epochs, saves the resulting weights under
    ``parameters/`` for the given *round*, and returns the sampled
    iteration counters together with their recorded loss values.
    """
    dataset = SiameseDataset('images.txt', labellist,
                             transforms.ToTensor())
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True)

    # Model, loss and optimiser for this run
    net = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    counter = []
    loss_history = []
    iteration_number = 0

    for epoch in range(train_number_epochs):
        running_loss = 0
        started = datetime.now()
        for batch_idx, (img0, img1, label) in enumerate(loader):
            img0 = img0.to(device)
            img1 = img1.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            out1, out2 = net(img0, img1)
            batch_loss = criterion(out1, out2, label)
            batch_loss.backward()
            running_loss += batch_loss.item()
            optimizer.step()

            # Sample the loss curve every 20 batches for plotting
            if batch_idx % 20 == 0:
                iteration_number += 20
                counter.append(iteration_number)
                loss_history.append(batch_loss.item())
        finished = datetime.now()
        print("Epoch number: {} , Current loss: {:.4f}, Epoch Time: {}".format(
            epoch + 1, running_loss / (batch_idx + 1), finished - started))

    torch.save(net.state_dict(),
               "parameters/Parm_Round_" + str(round + 1) + ".pth")
    return counter, loss_history
예제 #3
0
            # accuracy sketches
            _, max_idx = torch.max(pred_logits_second, dim=1)
            running_acc_sketches += torch.sum(max_idx == second_label).item()
            avg_sketches_val_acc = running_acc_sketches / items * 100

        avg_val_loss = val_total_loss / len(val_loader)

        if not args.debug:
            wandb.log({
                'val/loss': avg_val_loss,
                'val/acc flickr': avg_val_flicker_acc,
                'val/acc sketches': avg_sketches_val_acc,
                'epoch': epoch + 1
            })

            # checkpointing
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
            model_name = f"best_{contrastive_net.__class__.__name__}_contrastive.pth"
            torch.save(contrastive_net.state_dict(), model_name)
            if not args.debug:
                wandb.save(model_name)
            tag = '*'

            sys.stdout.write(
                f", Val[Loss: {avg_val_loss}, flickr acc: {avg_val_flicker_acc}, sketches acc: {avg_sketches_val_acc} {tag}"
            )
            sys.stdout.flush()
            sys.stdout.write('\n')
        loss = siamese_loss + ae_loss
        # loss = loss_function(img_batch, decoded_img)

        # print(loss.item())
        loss.backward()

        optimizer.step()

    epoch_loss = epoch_loss_autoencoder + epoch_loss_siamese
    print(f'siamese: {epoch_loss_siamese}, autoencoder: {epoch_loss_autoencoder}, all: {epoch_loss}')
    
    intloss = int(epoch_loss * 10000) / 10000
    if epoch % config.save_frequency == 0:
        torch.save(autoencoder.state_dict(), f'{config.saved_models_folder}/autoencoder_epoch{epoch}_loss{intloss}.pth')
        torch.save(siamese_network.state_dict(), f'{config.saved_models_folder}/siamese_network_epoch{epoch}_loss{intloss}.pth')
        print('Saved models, epoch: ' + str(epoch))

# Pull a single batch of images from the training loader
# (batch[0] is the image tensor; any labels are discarded).
for batch in train_data_loader:
    batch = batch[0]
    break
batch = batch.to(device)
# NOTE(review): transform2 is defined elsewhere -- presumably a
# preprocessing/normalisation step; confirm against its definition.
batch = transform2(batch)

# Sanity check: encode one image and feed the SAME feature vector to both
# branches of the siamese network -- the printed result should indicate
# "same image".
print('test')
print('the same images')
img1 = batch[0]
features, decoded_img = autoencoder(img1.unsqueeze(0))
result = siamese_network(features, features)
print(result)
print()
예제 #5
0
        loss = loss_function(output, target)
        # print(loss.item())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(epoch_loss)

    intloss = int(epoch_loss * 10000) / 10000
    if epoch % config.save_frequency == 0:
        torch.save(
            encoder.state_dict(),
            f'{config.saved_models_folder}/encoder_epoch{epoch}_loss{intloss}.pth'
        )
        torch.save(
            siamese_network.state_dict(),
            f'{config.saved_models_folder}/siamese_network_epoch{epoch}_loss{intloss}.pth'
        )
        print('Saved models, epoch: ' + str(epoch))

# Pull a single batch of images from the training loader
# (batch[0] is the image tensor; any labels are discarded).
for batch in train_data_loader:
    batch = batch[0]
    break
batch = batch.to(device)
# NOTE(review): transform2 is defined elsewhere -- presumably a
# preprocessing/normalisation step; confirm against its definition.
batch = transform2(batch)

# Sanity check: encode one image and feed the SAME feature vector to both
# branches of the siamese network -- the result should indicate "same image".
print('test')
print('the same images')
img1 = batch[0]
features = encoder(img1.unsqueeze(0))
result = siamese_network(features, features)
예제 #6
0
            _, max_idx = torch.max(pred_logits_negative, dim=1)
            running_acc_negative += torch.sum(max_idx == n_l).item()
            avg_acc_val_negative = running_acc_negative / items * 100

        avg_val_loss = val_total_loss / len(val_loader)

        if not args.debug:
            wandb.log({
                'val/loss': avg_val_loss,
                'val/acc anchor': avg_acc_val_anchor,
                'val/acc positive': avg_acc_val_positive,
                'val/acc negative': avg_acc_val_negative,
                'epoch': epoch + 1
            })

            # checkpointing
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
            model_name = f"best_{siamese_net.__class__.__name__}_triplet.pth"
            torch.save(siamese_net.state_dict(), model_name)
            if not args.debug:
                wandb.save(model_name)
            tag = '*'

            sys.stdout.write(f", Val[Loss: {avg_val_loss}, "
                             f"anchor acc: {avg_acc_val_anchor}, "
                             f"positive acc: {avg_acc_val_positive} "
                             f"negative acc: {avg_acc_val_negative} {tag}")
            sys.stdout.flush()
            sys.stdout.write('\n')
예제 #7
0
 if args.fit_mode:
     # +) tensorboardX, log
     if not os.path.exists('logs/tune'):
         os.mkdir('logs/tune')
     writer = SummaryWriter('logs/tune/{}'.format(model_tag))
     dev_losses = []
     for epoch in range(start, args.num_epochs + 1):
         train_loss = train_fit(train_loader, model, device, criterion,
                                optim)
         dev_loss = dev_fit(dev_loader, model, device, criterion)
         writer.add_scalar('train_loss', train_loss, epoch)
         writer.add_scalar('dev_loss', dev_loss, epoch)
         print('\n{} - train loss: {:.5f} - dev loss: {:.5f}'.format(
             epoch, train_loss, dev_loss))
         torch.save(
             model.state_dict(),
             os.path.join(model_save_path,
                          'epoch_{}.pth'.format(epoch)))
         dev_losses.append(dev_loss)
         #early_stopping(dev_loss, model)
         #if early_stopping.early_stop:
         #    print('early stopping !')
         #    break
         scheduler.step(dev_loss)
     minposs = dev_losses.index(min(dev_losses)) + 1
     print('lowest dev loss at epoch is {}'.format(minposs))
 # +) cp-fine-tuning mode
 elif args.cp_fit_mode:
     # +) tensorboardX, log
     if not os.path.exists('logs/cp-tune'):
         os.mkdir('logs/cp-tune')
예제 #8
0
# Directory holding the metrics CSV and the model checkpoints for this
# configuration. Built once: the same four values were previously
# re-formatted on every epoch for both the CSV path and the save path.
output_dir = 'outputs/{}-way-{}-shot_{}_{}'.format(class_per_set,
                                                   sample_per_class,
                                                   net,
                                                   dataset)

print("===================={}-Way-{}-Shot-{}-{} Learning====================".format(class_per_set,
                                                                                     sample_per_class,
                                                                                     net,
                                                                                     dataset))

for e in range(n_epochs):
    # Running totals of accuracy and dice loss over the test episodes
    ac = 0.0
    dicel = 0.0

    print("====================Epoch:{}====================".format(e))
    # train_one_epoch(data, batches, model, optimizer)

    # Average the metrics over a fixed number of test episodes
    tests = 500
    for i in range(tests):
        acc, dice_loss = test_nets(data, model, batch_size)
        ac += acc
        dicel += dice_loss

    print("====================Test: test_accuracy:{} test_dice_loss:{}====================".format(ac/tests, dicel/tests))

    # Append this epoch's averages to the metrics table.
    # BUGFIX: the csv module requires newline='' -- without it, extra
    # blank rows are written on Windows.
    with open(os.path.join(output_dir, 'metrics_table.csv'), 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([e, ac/tests, dicel/tests])

    # Periodic checkpointing
    if e % args.eval_epochs == 0:
        torch.save(model.state_dict(),
                   os.path.join(output_dir, 'model_{}.pkl'.format(e)))