diff_loss = criterion(image, negative_sample) # Contrastive loss function loss = (same_loss ** 2).mean() + (torch.max(M - diff_loss, torch.zeros_like(diff_loss)) ** 2).mean() # Zero the parameter gradients optimizer.zero_grad() # Compute gradient using backpropagation loss.backward() # Take a optimizer step optimizer.step() # Save batch loss epoch_loss += loss.data.item() # Monitoring performance if i % display_step == 0: print('Epoch: %2d, step: %4d, mean sinkhorn divergence: %.4f' % (epoch, i, epoch_loss / i)) i += 1 # break # Saving model print('\nRunning time: {}\n'.format(diff(datetime.now(), time))) torch.save(model.state_dict(), "../models/point_cloud_embedding.model")
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for training the
    Siamese Network model. All hyperparameters are provided by the user
    in the config file.
    """

    def __init__(self, config, data_loader, layer_hyperparams):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. A (train, valid) pair when
          config.is_train, a single test loader otherwise.
        - layer_hyperparams: dict containing layer-wise hyperparameters
          such as the initial learning rate, the end momentum, and the
          l2 regularization strength.
        """
        self.config = config
        self.layer_hyperparams = layer_hyperparams

        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.dataset)
            self.num_valid = self.valid_loader.dataset.trials
        else:
            self.test_loader = data_loader
            self.num_test = self.test_loader.dataset.trials

        self.model = SiameseNet()
        if config.use_gpu:
            self.model.cuda()

        # model params
        self.num_params = sum(
            p.data.nelement() for p in self.model.parameters())
        self.num_model = get_num_model(config)
        self.num_layers = len(list(self.model.children()))
        print('[*] Number of model parameters: {:,}'.format(self.num_params))

        # path params
        self.ckpt_dir = os.path.join(config.ckpt_dir, self.num_model)
        self.logs_dir = os.path.join(config.logs_dir, self.num_model)

        # misc params
        self.resume = config.resume
        self.use_gpu = config.use_gpu
        self.dtype = (torch.cuda.FloatTensor if self.use_gpu
                      else torch.FloatTensor)

        # optimization params
        self.best = config.best
        self.best_valid_acc = 0.
        self.epochs = config.epochs
        self.start_epoch = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        # number of consecutive epochs without validation improvement
        self.counter = 0

        # grab layer-wise hyperparams; kept for the SGD momentum-tempering
        # scheme even though the active optimizer below is Adam
        self.init_lrs = self.layer_hyperparams['layer_init_lrs']
        self.init_momentums = [config.init_momentum] * self.num_layers
        self.end_momentums = self.layer_hyperparams['layer_end_momentums']
        self.l2_regs = self.layer_hyperparams['layer_l2_regs']

        # per-epoch linear momentum increase rate; guard the single-epoch
        # case to avoid division by zero (params renamed from max/min to
        # avoid shadowing builtins)
        if self.epochs == 1:
            f = lambda end, start: start
        else:
            f = lambda end, start: (end - start) / (self.epochs - 1)
        self.momentum_temper_rates = [
            f(x, y) for x, y in zip(self.end_momentums, self.init_momentums)
        ]

        # set global learning rates and momentums
        self.lrs = self.init_lrs
        self.momentums = self.init_momentums

        # NOTE(review): an earlier per-layer SGD optimizer (and a StepLR
        # scheduler) was replaced by a single Adam optimizer, which is why
        # temper_momentum()/decay_lr() below are currently unused.
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=3e-4, weight_decay=6e-5,
        )

    def train(self):
        """Run the training loop with early stopping and checkpointing."""
        if self.resume:
            self.load_checkpoint(best=False)

        # switch to train mode
        self.model.train()

        # create train and validation log files
        optim_file = open(os.path.join(self.logs_dir, 'optim.csv'), 'w')
        train_file = open(os.path.join(self.logs_dir, 'train.csv'), 'w')
        valid_file = open(os.path.join(self.logs_dir, 'valid.csv'), 'w')

        print("\n[*] Train on {} sample pairs, validate on {} trials".format(
            self.num_train, self.num_valid))

        try:
            for epoch in range(self.start_epoch, self.epochs):
                print('\nEpoch: {}/{}'.format(epoch + 1, self.epochs))

                train_loss = self.train_one_epoch(epoch, train_file)
                valid_acc = self.validate(epoch, valid_file)

                # check for improvement
                is_best = valid_acc > self.best_valid_acc
                msg = "train loss: {:.3f} - val acc: {:.3f}"
                if is_best:
                    msg += " [*]"
                    self.counter = 0
                print(msg.format(train_loss, valid_acc))

                # early stopping: stop once patience is exhausted
                if not is_best:
                    self.counter += 1
                    if self.counter > self.train_patience:
                        print("[!] No improvement in a while, stopping training.")
                        return

                # checkpoint the model
                self.best_valid_acc = max(valid_acc, self.best_valid_acc)
                self.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.model.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': self.best_valid_acc,
                    }, is_best)
        finally:
            # release resources; the finally block also covers the
            # early-stopping return above, which previously leaked the
            # three open file handles
            optim_file.close()
            train_file.close()
            valid_file.close()

    def train_one_epoch(self, epoch, file):
        """Train for one epoch, logging per-step loss to `file`.

        Returns the mean training loss over the epoch.
        """
        train_batch_time = AverageMeter()
        train_losses = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x1, x2, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x1, x2, y = x1.cuda(), x2.cuda(), y.cuda()
                # NOTE: the deprecated Variable() wrapping was dropped;
                # plain tensors carry autograd since PyTorch 0.4

                batch_size = x1.shape[0]
                out = self.model(x1, x2)
                loss = F.binary_cross_entropy_with_logits(out, y)

                # compute gradients and update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # store batch statistics
                # BUG FIX: loss.data[0] raises IndexError on 0-dim loss
                # tensors (PyTorch >= 0.5); .item() is the supported way
                toc = time.time()
                train_losses.update(loss.item(), batch_size)
                train_batch_time.update(toc - tic)
                tic = time.time()

                pbar.set_description("{:.1f}s - loss: {:.3f}".format(
                    train_batch_time.val, train_losses.val,
                ))
                pbar.update(batch_size)

                # log loss (renamed from `iter`, which shadows a builtin)
                step = (epoch * len(self.train_loader)) + i
                file.write('{},{}\n'.format(step, train_losses.val))

        return train_losses.avg

    def validate(self, epoch, file):
        """Evaluate one-shot accuracy over the validation trials."""
        # switch to evaluate mode
        self.model.eval()

        correct = 0
        # torch.no_grad() replaces the deprecated volatile=True Variables
        with torch.no_grad():
            for i, (x1, x2) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x1, x2 = x1.cuda(), x2.cuda()

                # compute similarity scores (torch.sigmoid replaces the
                # deprecated F.sigmoid)
                out = self.model(x1, x2)
                probas = torch.sigmoid(out)

                # a trial counts as correct when index 0 scores highest;
                # assumes the loader puts the matching pair first in each
                # trial batch -- TODO confirm against the valid loader
                pred = probas.view(-1).max(0)[1].item()
                if pred == 0:
                    correct += 1

        # compute acc and log
        valid_acc = (100. * correct) / self.num_valid
        file.write('{},{}\n'.format(epoch, valid_acc))
        return valid_acc

    def test(self):
        """Evaluate one-shot accuracy over the test trials."""
        # load best model
        self.load_checkpoint(best=self.best)

        # switch to evaluate mode
        self.model.eval()

        correct = 0
        with torch.no_grad():
            for i, (x1, x2) in enumerate(self.test_loader):
                if self.use_gpu:
                    x1, x2 = x1.cuda(), x2.cuda()

                out = self.model(x1, x2)
                probas = torch.sigmoid(out)

                # same index-0 convention as validate() -- TODO confirm
                pred = probas.view(-1).max(0)[1].item()
                if pred == 0:
                    correct += 1

        test_acc = (100. * correct) / self.num_test
        print("[*] Test Acc: {}/{} ({:.2f}%)".format(
            correct, self.num_test, test_acc))

    def temper_momentum(self, epoch):
        """
        This function linearly increases the per-layer momentum to a
        predefined ceiling over a set amount of epochs.

        NOTE(review): only meaningful with a per-layer SGD optimizer;
        currently unused because the active Adam param groups carry no
        'momentum' key.
        """
        if epoch == 0:
            return
        self.momentums = [
            x + y for x, y in zip(self.momentums, self.momentum_temper_rates)
        ]
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group['momentum'] = self.momentums[i]

    def decay_lr(self):
        """
        This function linearly decays the per-layer lr over a set amount
        of epochs.

        NOTE(review): self.scheduler is never created (its setup was
        removed from __init__), so calling this would raise
        AttributeError; it is currently unused.
        """
        self.scheduler.step()
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.lrs[i] = param_group['lr']

    def save_checkpoint(self, state, is_best):
        """Persist `state`; additionally copy it to the best-checkpoint
        file when `is_best` is True."""
        filename = 'model_ckpt.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = 'best_model_ckpt.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Restore model/optimizer state from the latest (or best)
        checkpoint in self.ckpt_dir."""
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = 'model_ckpt.tar'
        if best:
            filename = 'best_model_ckpt.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
net.eval() print("\n ...Valid") sensitivity_valid = [] for _, (valid1, valid2, label_valid) in tqdm(enumerate(validLoader, 1)): if Flags.cuda: test1, test2 = valid1.cuda(), valid2.cuda() else: test1, test2 = Variable(valid1), Variable(valid2) output_net = net.forward(test1, test2) y_actual = [] y_hat = [] for i in range(output_net.size()[0]): output_net_np = math.ceil(output_net[i].data.cpu().numpy()) y_actual.append(1) if output_net_np == 1.0 or output_net_np == 1: y_hat.append(1) else: y_hat.append(0) TP, FP, TN, FN = measure(y_actual, y_hat) if TP == 0 or FN == 0: sensitivity = 0 else: sensitivity = 100*(TP/(TP+FN)) sensitivity_valid.append(sensitivity) sensitivity_list.append(np.mean(sensitivity_valid)) plot_sensitivity(sensitivity_list,save_path) if epoch % Flags.save_every == 0: print("\n ...Save model") torch.save(net.state_dict(),os.path.join(model_path,"model_"+str(epoch_valid)+'.pt')) epoch_valid += 1
def train(self):
    """Train the Siamese network with one-cycle lr scheduling.

    Builds loaders, model, optimizer, and criterion from self.config,
    optionally resumes from a checkpoint, then runs the train/validation
    loop with TensorBoard logging, early stopping, and checkpointing.
    """
    # Dataloader
    train_loader, valid_loader = get_train_validation_loader(
        self.config.data_dir, self.config.batch_size,
        self.config.num_train, self.config.augment, self.config.way,
        self.config.valid_trials, self.config.shuffle, self.config.seed,
        self.config.num_workers, self.config.pin_memory)

    # Model, Optimizer, criterion
    model = SiameseNet()
    if self.config.optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=self.config.lr)
    else:
        optimizer = optim.Adam(model.parameters())
    criterion = torch.nn.BCEWithLogitsLoss()

    if self.config.use_gpu:
        model.cuda()

    # Load check point; the one-cycle schedule spans only the remaining
    # steps when resuming
    if self.config.resume:
        start_epoch, best_epoch, best_valid_acc, model_state, optim_state = \
            self.load_checkpoint(best=False)
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optim_state)
        one_cycle = OneCyclePolicy(
            optimizer, self.config.lr,
            (self.config.epochs - start_epoch) * len(train_loader),
            momentum_rng=[0.85, 0.95])
    else:
        best_epoch = 0
        start_epoch = 0
        best_valid_acc = 0
        one_cycle = OneCyclePolicy(
            optimizer, self.config.lr,
            self.config.epochs * len(train_loader),
            momentum_rng=[0.85, 0.95])

    # create tensorboard summary and add model structure.
    writer = SummaryWriter(os.path.join(self.config.logs_dir, 'logs'),
                           filename_suffix=self.config.num_model)
    im1, im2, _ = next(iter(valid_loader))
    writer.add_graph(model, [im1.to(self.device), im2.to(self.device)])

    counter = 0          # epochs without validation improvement
    num_train = len(train_loader)    # number of train batches
    num_valid = len(valid_loader)    # number of validation trials/batches
    print(
        f"[*] Train on {len(train_loader.dataset)} sample pairs, validate on {valid_loader.dataset.trials} trials"
    )

    # Train & Validation
    # BUG FIX: the whole loop is wrapped in try/finally so the
    # SummaryWriter is closed even when early stopping returns below
    # (it previously leaked on that path).
    try:
        main_pbar = tqdm(range(start_epoch, self.config.epochs),
                         initial=start_epoch, position=0,
                         total=self.config.epochs, desc="Process")
        for epoch in main_pbar:
            train_losses = AverageMeter()
            valid_losses = AverageMeter()

            # TRAIN
            model.train()
            train_pbar = tqdm(enumerate(train_loader), total=num_train,
                              desc="Train", position=1, leave=False)
            for i, (x1, x2, y) in train_pbar:
                if self.config.use_gpu:
                    x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(
                        self.device)
                out = model(x1, x2)
                loss = criterion(out, y.unsqueeze(1))

                # compute gradients and update
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                one_cycle.step()

                # store batch statistics
                train_losses.update(loss.item(), x1.shape[0])

                # log loss
                writer.add_scalar("Loss/Train", train_losses.val,
                                  epoch * len(train_loader) + i)
                train_pbar.set_postfix_str(f"loss: {train_losses.val:0.3f}")

            # VALIDATION
            model.eval()
            valid_acc = 0
            correct_sum = 0
            valid_pbar = tqdm(enumerate(valid_loader), total=num_valid,
                              desc="Valid", position=1, leave=False)
            with torch.no_grad():
                for i, (x1, x2, y) in valid_pbar:
                    if self.config.use_gpu:
                        x1, x2, y = x1.to(self.device), x2.to(
                            self.device), y.to(self.device)

                    # compute log probabilities
                    out = model(x1, x2)
                    loss = criterion(out, y.unsqueeze(1))

                    # argmax over the whole batch: a trial is correct when
                    # index 0 has the highest score -- assumes each batch
                    # is one n-way trial with the true pair first; TODO
                    # confirm against the valid loader
                    y_pred = torch.sigmoid(out)
                    y_pred = torch.argmax(y_pred)
                    if y_pred == 0:
                        correct_sum += 1

                    # store batch statistics
                    valid_losses.update(loss.item(), x1.shape[0])

                    # compute running acc and log
                    valid_acc = correct_sum / num_valid
                    writer.add_scalar("Loss/Valid", valid_losses.val,
                                      epoch * len(valid_loader) + i)
                    valid_pbar.set_postfix_str(f"accuracy: {valid_acc:0.3f}")

            writer.add_scalar("Acc/Valid", valid_acc, epoch)

            # check for improvement
            if valid_acc > best_valid_acc:
                is_best = True
                best_valid_acc = valid_acc
                best_epoch = epoch
                counter = 0
            else:
                is_best = False
                counter += 1

            # early stopping once patience is exhausted
            if counter > self.config.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return

            # checkpoint the model on improvement, every 5 epochs, and on
            # the final epoch
            # BUG FIX: `epoch == self.config.epochs` was never true because
            # range() tops out at epochs - 1; compare against epochs - 1.
            if is_best or epoch % 5 == 0 or epoch == self.config.epochs - 1:
                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'model_state': model.state_dict(),
                        'optim_state': optimizer.state_dict(),
                        'best_valid_acc': best_valid_acc,
                        'best_epoch': best_epoch,
                    }, is_best)

            main_pbar.set_postfix_str(
                f"best acc: {best_valid_acc:.3f} best epoch: {best_epoch} ")
            tqdm.write(
                f"[{epoch}] train loss: {train_losses.avg:.3f} - valid loss: {valid_losses.avg:.3f} - valid acc: {valid_acc:.3f} {'[BEST]' if is_best else ''}"
            )
    finally:
        # release resources
        writer.close()
def main():
    """Train a Siamese network with SGD and per-epoch geometric lr decay.

    Uses a hand-written binary cross-entropy on the model's probability
    output, validates every epoch, and early-stops when the (rounded)
    validation error is flat for 20 consecutive epochs. Checkpoints are
    saved every 20 epochs and at early stop.
    """
    hp = get_hparams()

    transform = transforms.Compose([
        transforms.ToTensor()])

    train_loader = get_loader(hp.bg_data_path, hp.ev_data_path,
                              hp.batch_size, hp.dataset_size, True,
                              transform, mode="train")
    valid_loader = get_loader(hp.bg_data_path, hp.ev_data_path, hp.num_way,
                              hp.valid_trial * hp.num_way, False, transform,
                              mode="valid")
    # NOTE(review): test_loader is built but never used in this function.
    test_loader = get_loader(hp.bg_data_path, hp.ev_data_path, hp.num_way,
                             hp.test_trial * hp.num_way, False, transform,
                             mode="test")

    model = SiameseNet().to(device)

    def weights_init(m):
        # Gaussian init for conv/linear weights and biases; presumably the
        # scheme from the original Siamese-network paper -- verify.
        if isinstance(m, nn.Conv2d):
            torch.nn.init.normal_(m.weight, 0.0, 1e-2)
            torch.nn.init.normal_(m.bias, 0.5, 1e-2)
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, 0.0, 0.2)
            torch.nn.init.normal_(m.bias, 0.5, 1e-2)

    model.apply(weights_init)

    num_epochs = hp.num_epochs
    total_step = len(train_loader)
    stop_decision = 1   # consecutive epochs with identical valid error
    prev_error = 0.0

    for epoch in range(num_epochs):
        # geometric lr decay, implemented by rebuilding the optimizer each
        # epoch; NOTE(review): this also resets SGD momentum buffers every
        # epoch -- confirm that is intended
        lr = hp.learning_rate * pow(0.99, epoch)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=hp.momentum,
                                    weight_decay=hp.reg_scale)

        for i, (images_1, images_2, label) in enumerate(train_loader):
            images_1 = images_1.to(device).float()
            images_2 = images_2.to(device).float()
            label = label.to(device).float()

            # hand-written binary cross-entropy; the model outputs
            # probabilities, not logits (torch.log would produce NaN/inf
            # for prob exactly 0 or 1)
            prob = model(images_1, images_2)
            obj = label * torch.log(prob) + (1. - label) * torch.log(1. - prob)
            loss = -torch.sum(obj) / float(hp.batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % hp.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch+1, num_epochs, i+1, total_step,
                              loss.item()))

        # validation: mean per-trial error over the validation set
        valid_errors = 0.0
        total_sample = 0.0
        # BUG FIX: evaluate under no_grad so autograd graphs are not built
        # for every validation batch; outputs are unchanged but memory no
        # longer grows with the validation set
        with torch.no_grad():
            for images_1, images_2, label in valid_loader:
                images_1 = images_1.to(device).float()
                images_2 = images_2.to(device).float()
                label = label.to(device).float()

                prob = model(images_1, images_2)
                obj = label * torch.log(prob) + (1. - label) * torch.log(1.
                                                                         - prob)
                valid_errors += -torch.sum(obj).detach().cpu().numpy() / float(hp.num_way)
                total_sample += 1.0

        valid_error = np.round(valid_errors / total_sample, 4)
        print('Epoch [{}/{}], Validation Error : {:.4f}'
              .format(epoch+1, num_epochs, valid_error))

        # early stopping on a plateau: exact equality of the rounded error
        # for 20 consecutive epochs
        if valid_error == prev_error:
            stop_decision += 1
        else:
            stop_decision = 1
        if stop_decision == 20:
            print('Epoch [{}/{}], Early Stopped Training!'.format(epoch+1, num_epochs))
            torch.save(model.state_dict(), os.path.join(
                hp.model_path, 'siamese-{}.ckpt'.format(epoch+1)))
            break
        prev_error = valid_error

        # periodic checkpoint
        if (epoch + 1) % 20 == 0:
            torch.save(model.state_dict(), os.path.join(
                hp.model_path, 'siamese-{}.ckpt'.format(epoch+1)))