def main():
    """Build the Siamese model, optionally resume from a checkpoint, and train.

    All settings are read from the module-level ``config`` object
    (``visdom_name``, ``cuda``, ``lr``, ``resume``, ``start_epoch``,
    ``epochs``). ``train`` and ``train_loader`` are not defined here and are
    presumably module-level names defined elsewhere -- NOTE(review): confirm.

    Side effects:
        * rebinds the module-global ``plotter`` to a new ``VisdomLinePlotter``;
        * when resuming from an existing checkpoint file, overwrites
          ``config.start_epoch`` with the checkpoint's ``'epoch'`` entry.
    """
    global plotter
    plotter = VisdomLinePlotter(env_name=config.visdom_name)

    # Instantiate the model and move it to the GPU if requested.
    model = SiameseNetwork()
    if config.cuda:
        model.cuda()

    optimizer = create_optimizer(model, config.lr)

    # Optionally resume from a checkpoint.
    if config.resume:
        if os.path.isfile(config.resume):
            print('=> loading checkpoint {}'.format(config.resume))
            # Deserialize the checkpoint once and reuse it for both the epoch
            # counter and the model weights.
            checkpoint = torch.load(config.resume)
            config.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # NOTE(review): only 'epoch' and 'state_dict' are restored; the
            # optimizer's state is not, so resumed training restarts it fresh.
        else:
            print('=> no checkpoint found at {}'.format(config.resume))

    start = config.start_epoch
    end = start + config.epochs
    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
model_name = config.model_name prediction_file = config.prediction_file best_prediction_file = config.best_prediction_file #DBY batch = config.batch mode = config.mode # create model model = SiameseNetwork() #model = SiameseEfficientNet() model = Vgg19() if mode == 'test': load_model(model_name, model) if cuda: model = model.cuda() # Define 'best loss' - DBY best_loss = 0.1 last_loss = 0 if mode == 'train': # define loss function # loss_fn = nn.CrossEntropyLoss() # if cuda: # loss_fn = loss_fn.cuda() class ContrastiveLoss(torch.nn.Module): """ Contrastive loss function. Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf """ def __init__(self, margin=2.0):