def train_model(train_dl, model):
    """Train `model` on `train_dl` for `num_epoch` epochs with SGD + cross-entropy.

    Relies on module-level globals: `num_epoch`, `device`, `save_checkpoint`.
    A checkpoint (model + optimizer state) is saved every second epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    n_total_steps = len(train_dl)
    for epoch in range(num_epoch):
        if epoch % 2 == 0:
            checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            save_checkpoint(checkpoint)
        for i, (inputs, targets) in enumerate(train_dl):
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Forward pass
            yhat = model(inputs)
            loss = criterion(yhat, targets)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 2 == 0:
                # BUGFIX: the original printed `epoch+1/num_epoch`, which due to
                # operator precedence is `epoch + (1/num_epoch)` — a float, not a
                # "current/total" progress indicator (same for the step).
                print("Epoch:", f"{epoch + 1}/{num_epoch}",
                      "Step:", f"{i + 1}/{n_total_steps}",
                      "Loss:", loss.item())
def main():
    """Full training driver: build loaders/model/optimizer from `args`,
    train for `args.epochs` epochs, log per-epoch metrics, and overwrite
    a single checkpoint file each epoch.

    Relies on module-level globals: `args` and the project helpers
    `get_dataset`, `get_architecture`, `init_logfile`, `train`, `test`, `log`.
    """
    if args.gpu:
        # Restrict visible devices before any CUDA context is created.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
        # os.mkdir(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    # pin_memory speeds host->GPU copies; only worthwhile for the large dataset.
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

    for epoch in range(args.epochs):
        # NOTE(review): calling scheduler.step(epoch) *before* training is the
        # pre-1.1 PyTorch idiom; modern PyTorch warns and expects step() after
        # optimizer steps — confirm the installed torch version before changing.
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion, optimizer,
                                      epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion, args.noise_sd)
        after = time.time()

        # NOTE(review): the second field formats str(timedelta) through `{:.3}`,
        # which truncates the string to its first 3 characters (e.g. "0:0").
        log(
            logfilename,
            "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss, test_acc))

        # Overwrites the same file each epoch — only the last epoch survives.
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
def save_checkpoints(epoch: int, model: Module, optimizer: SGD, loss: _Loss, path: str):
    """Serialize a training checkpoint to `path`.

    The checkpoint bundles the epoch counter, the model and optimizer state
    dicts, and the loss object so training can be resumed later.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, path)
def main():
    """Train on the 40X BreakHis-style image folders, validate each epoch, and
    checkpoint, flagging the best validation accuracy seen so far.

    Relies on module-level globals: `Config`, `DATA_DIR`, `set_model`,
    `adjust_learing_rate`, `train_epoch`, `validate`, `save_checkpoint`,
    and the `best_val_acc` / `best_test_acc` globals.
    """
    cudnn.benchmark = True
    batch_size = Config.gpu_count * Config.image_per_gpu
    EPOCHS = Config.epoch
    workers = Config.workers
    global best_val_acc, best_test_acc
    Config.distributed = Config.gpu_count > 4  # TODO!

    model = set_model()
    #if Config.gpu is not None:
    model = model.cuda()
    if Config.gpu_count > 1:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    #weights = torch.FloatTensor(np.array([0.7, 0.3])).cuda()
    #criterion = WeightCrossEntropy(num_classes=Config.out_class, weight=weights).cuda()
    #criterion = LGMLoss(num_classes=Config.out_class, feat_dim=128).cuda()
    optimizer = SGD(model.parameters(), lr=Config.lr, momentum=0.9, nesterov=True,
                    weight_decay=0.0001)
    #optimizer = Adam(model.parameters())

    train_dir = os.path.join(DATA_DIR, 'train', '40X')
    val_dir = os.path.join(DATA_DIR, 'val', '40X')
    test_dir = os.path.join(DATA_DIR, 'test', '40X')

    TRANSFORM_IMG = transforms.Compose([
        transforms.Resize((256, 256)),
        #ImageTransform(),
        #lambda x: PIL.Image.fromarray(x),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.2, 0.2, 0.2])
    ])

    train_loader = DataLoader(ImageFolder(root=train_dir, transform=TRANSFORM_IMG),
                              batch_size=batch_size, shuffle=True, pin_memory=True,
                              num_workers=workers)
    # BUGFIX: the validation loader previously read root=test_dir while val_dir
    # was computed and never used — validation silently ran on the test split.
    val_loader = DataLoader(ImageFolder(root=val_dir, transform=TRANSFORM_IMG),
                            batch_size=batch_size, shuffle=True, pin_memory=True,
                            num_workers=workers)
    #test_loader = DataLoader(ImageFolder(root=test_dir, transform=TRANSFORM_IMG),
    #                         batch_size=batch_size, shuffle=True, pin_memory=True,
    #                         num_workers=workers)

    for epoch in range(EPOCHS):
        adjust_learing_rate(optimizer, epoch)
        train_losses, train_acc = train_epoch(train_loader, model, criterion,
                                              optimizer, epoch)
        val_losses, val_acc = validate(val_loader, model, criterion)
        is_best = val_acc.avg > best_val_acc
        # BUGFIX: best_val_acc was never updated, so every epoch after the first
        # compared against the initial value and the checkpoint stored a stale
        # 'best_val_acc'.
        best_val_acc = max(val_acc.avg, best_val_acc)
        print('>>>>>>>>>>>>>>>>>>>>>>')
        print('Epoch: {} train loss: {}, train acc: {}, valid loss: {}, valid acc: {}'.format(
            epoch, train_losses.avg, train_acc.avg, val_losses.avg, val_acc.avg))
        print('>>>>>>>>>>>>>>>>>>>>>>')
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_val_acc': best_val_acc,
                         'optimizer': optimizer.state_dict(),}, is_best)
def main(args):
    """Train the selected CNN on .npy sample folders, evaluating each epoch and
    checkpointing whenever validation accuracy improves.

    Relies on project helpers: `RunManager`, `load_sample`, `normalize_sample`,
    `RectCNN`, `PaperCNN`, `train`, `evaluate`.
    """
    run = RunManager(args, ignore=('device', 'evaluate', 'no_cuda'), main='model')
    print(run)

    train_dataset = DatasetFolder('data/train', load_sample, ('.npy', ),
                                  transform=normalize_sample)
    val_dataset = DatasetFolder('data/val', load_sample, ('.npy', ),
                                transform=normalize_sample)
    print(train_dataset)
    print(val_dataset)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=8)

    if args.model == '1d-conv':
        model = RectCNN(282)
    else:
        model = PaperCNN()
    model = model.double().to(args.device)
    optimizer = SGD(model.parameters(), lr=1e-2)

    # evaluate(val_loader, model, args)
    best = 0
    # BUGFIX: trange(1, args.epochs) performed only args.epochs - 1 iterations;
    # the upper bound is exclusive, so the final epoch was silently skipped.
    progress = trange(1, args.epochs + 1)
    for epoch in progress:
        progress.set_description('TRAIN [CurBestAcc={:.2%}]'.format(best))
        train(train_loader, model, optimizer, args)
        progress.set_description('EVAL [CurBestAcc={:.2%}]'.format(best))
        metrics = evaluate(val_loader, model, args)
        is_best = metrics['acc'] > best
        best = max(metrics['acc'], best)
        if is_best:
            run.save_checkpoint(
                {
                    'epoch': epoch,
                    'params': vars(args),
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict(),
                    'metrics': metrics
                }, is_best)
        metrics.update({'epoch': epoch})
        run.pushLog(metrics)
def main():
    """Dict-args training driver: load data/model, train for args['epochs']
    epochs, log per-epoch metrics, and overwrite a single checkpoint file.

    Relies on module-level globals: `args` (a dict) and the project helpers
    `loaddata`, `loadmodel`, `init_logfile`, `train`, `test`, `log`.
    """
    if args['gpu']:
        os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
    if not os.path.exists(args['outdir']):
        # BUGFIX: os.mkdir fails when outdir is a nested path whose parents
        # don't exist yet; makedirs creates intermediate directories too.
        os.makedirs(args['outdir'])

    train_loader, test_loader = loaddata(args)
    # NOTE(review): `model` is only bound when CUDA is available; on a CPU-only
    # machine the lines below raise NameError — confirm CUDA is a hard
    # requirement for this script.
    if torch.cuda.is_available():
        model = loadmodel(args)
        model = model.cuda()

    logfilename = os.path.join(args['outdir'], 'log.txt')
    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'],
                    weight_decay=args['weight_decay'])
    scheduler = StepLR(optimizer, step_size=args['lr_step_size'], gamma=args['gamma'])

    for epoch in range(args['epochs']):
        # NOTE(review): step-before-train is the pre-1.1 PyTorch scheduler idiom.
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion, optimizer,
                                      epoch, args['noise_sd'])
        test_loss, test_acc = test(test_loader, model, criterion, args['noise_sd'])
        after = time.time()

        log(
            logfilename,
            "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss, test_acc))

        # Overwrites the same file each epoch — only the last epoch survives.
        torch.save(
            {
                'epoch': epoch + 1,
                'dataset': args['dataset'],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args['outdir'], 'checkpoint.pth.tar'))
def main_train_worker(args):
    """Build a CIFAR model for `args.arch`, train it for `args.epochs` epochs,
    validating and checkpointing after every epoch.

    The checkpoint path encodes dataset, architecture, and hyperparameters so
    parallel runs never collide.
    """
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # Construct the classifier and decide where its weights will land on disk.
    print("=> creating model '{}'".format(args.arch))
    model = MetaLearnerModelBuilder.construct_cifar_model(args.arch, args.dataset)
    model_path = '{}/train_pytorch_model/real_image_model/{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("after train, model will be saved to {}".format(model_path))

    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    sgd = SGD(model.parameters(), args.lr, weight_decay=args.weight_decay)
    cudnn.benchmark = True

    # Train/val loaders share a factory; the boolean selects the split.
    train_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset,
                                                             args.batch_size, True)
    val_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset,
                                                           args.batch_size, False)

    for epoch in range(0, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args)
        # One pass over the training split.
        train(train_loader, model, criterion, sgd, epoch, args)
        # Measure accuracy on the held-out split.
        validate(val_loader, model, criterion, args)
        # Persist full training state so runs can be resumed or inspected.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': sgd.state_dict(),
            },
            filename=model_path)
def train(cont=False):
    """Train a segmentation network (SETR / TransUNet / UNet variants) on
    Cityscapes, validating each epoch with early stopping and best-loss
    checkpointing.

    Args:
        cont: if True, resume from `best_ckpt_src` instead of starting fresh.

    Relies on module-level globals: `net`, `device`, `data_dir`, `IMG_DIM`,
    `CLASS_NUM`, `batch_size`, `epoch_num`, `iteration_num`, `lrate`,
    `momentum`, `wdecay`, `fine_tune_ratio`, `use_dice_loss`, `print_freq`,
    `tensorboard_freq`, `early_stop_tolerance`, `ckpt_src`, `best_ckpt_src`,
    plus the project model/dataset/metric helpers.
    """
    # for tensorboard tracking
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))
    writer = SummaryWriter()

    # init model — SETR variants also return auxiliary heads.
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)

    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM,
                                     mode="train", cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM,
                                     mode="val", cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    logger.info("(2) Dataset Initiated. ")

    # optimizer — if no epoch count is configured, derive it from the
    # configured iteration budget.
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(train_loader) + 1
    optim = SGD(model.parameters(), lr=lrate, momentum=momentum, weight_decay=wdecay)
    # optim = Adam(model.parameters(), lr=lrate)
    # Drop LR by 10x once the fine-tune fraction of training has elapsed.
    scheduler = lr_scheduler.MultiStepLR(optim,
                                         milestones=[int(epochs * fine_tune_ratio)],
                                         gamma=0.1)

    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # for continue training
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
        # Replay scheduler steps to restore the LR schedule position;
        # the warnings filter silences the step-order deprecation warning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(cur_epoch):
                scheduler.step()
    else:
        model = nn.DataParallel(model)
        model = model.to(device)

    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")

    # loss
    ce_loss = CrossEntropyLoss()
    if use_dice_loss:
        dice_loss = DiceLoss(CLASS_NUM)

    # loop over epochs
    # NOTE(review): iter_count restarts at 0 on resume (cont=True), so the
    # poly-decay below re-applies from the beginning — confirm intended.
    iter_count = 0
    epoch_bar = tqdm.tqdm(total=epochs, desc="Epoch", position=cur_epoch, leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))
    for e in range(epochs - cur_epoch):
        epoch = e + cur_epoch

        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader), desc="TrainBatch",
                                    position=0, leave=True)
        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img, mask_img = orig_img.float().to(device), mask_img.float().to(device)

            if net == "TransUNet-Base" or net == "TransUNet-Large":
                pred = model(orig_img)
            elif net == "SETR-PUP" or net == "SETR-MLA":
                if aux_layers is not None:
                    pred, _ = model(orig_img)
                else:
                    pred = model(orig_img)
            elif net == "UNet":
                pred = model(orig_img)

            loss_ce = ce_loss(pred, mask_img[:].long())
            if use_dice_loss:
                loss_dice = dice_loss(pred, mask_img, softmax=True)
                loss = 0.5 * (loss_ce + loss_dice)
            else:
                loss = loss_ce

            # Backward Propagation, Update weight and metrics
            optim.zero_grad()
            loss.backward()
            optim.step()

            # update learning rate
            # NOTE(review): this multiplies the *current* lr by the poly factor
            # every iteration, so the decay compounds instead of following
            # base_lr * (1 - it/max_it)^0.9 — looks unintentional; verify.
            for param_group in optim.param_groups:
                orig_lr = param_group['lr']
                param_group['lr'] = orig_lr * (1.0 - iter_count / iteration_num)**0.9
            iter_count += 1

            # Update loss
            trainLossMeter.update(loss.item())

            # print status
            if (batch_num + 1) % print_freq == 0:
                status = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter)
                logger.info(status)

            # log loss to tensorboard
            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) +
                    (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)

        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)

        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader), desc="ValidBatch",
                                    position=0, leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img, mask_img = orig_img.float().to(device), mask_img.float().to(device)

                if net == "TransUNet-Base" or net == "TransUNet-Large":
                    pred = model(orig_img)
                elif net == "SETR-PUP" or net == "SETR-MLA":
                    if aux_layers is not None:
                        pred, _ = model(orig_img)
                    else:
                        pred = model(orig_img)
                elif net == "UNet":
                    pred = model(orig_img)

                loss_ce = ce_loss(pred, mask_img[:].long())
                if use_dice_loss:
                    loss_dice = dice_loss(pred, mask_img, softmax=True)
                    loss = 0.5 * (loss_ce + loss_dice)
                else:
                    loss = loss_ce

                # Update loss
                validLossMeter.update(loss.item())

                # print status
                if (batch_num + 1) % print_freq == 0:
                    status = 'Validation: [{0}][{1}/{2}]\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                            epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter)
                    logger.info(status)

                # log loss to tensorboard
                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) +
                        (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)

        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))

        # update optim scheduler
        scheduler.step()

        # save checkpoint — only on improvement; otherwise count toward
        # early stopping.
        is_best = valid_loss < best_loss
        best_loss_tmp = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        else:
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss_tmp,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
            best_loss = best_loss_tmp
        epoch_bar.update(1)

    writer.close()
def train(start_path, beta):
    """Continue a Pareto-front solution on MultiMNIST via the CPMTL procedure:
    load the starting network from `start_path`, then take `num_steps`
    KKT-guided SGD steps along preference vector `beta`, checkpointing and
    evaluating after every step.

    Args:
        start_path: Path to a saved Pareto solution ('state_dict' checkpoint).
        beta: preference-direction tensor; moved to the training device.
    """
    # prepare hyper-parameters (fixed for reproducibility of the procedure)
    seed = 42
    cuda_enabled = True
    cuda_deterministic = False
    batch_size = 2048
    num_workers = 2
    shared = False
    stochastic = False
    kkt_momentum = 0.0
    create_graph = False
    grad_correction = False
    shift = 0.0
    tol = 1e-5
    damping = 0.1
    maxiter = 50
    lr = 0.1
    momentum = 0.0
    weight_decay = 0.0
    num_steps = 10
    verbose = False

    # prepare path — checkpoints go under cpmtl/<start-file-stem>/
    ckpt_name = start_path.name.split('.')[0]
    root_path = Path(__file__).resolve().parent
    dataset_path = root_path / 'MultiMNIST'
    ckpt_path = root_path / 'cpmtl' / ckpt_name
    if not start_path.is_file():
        raise RuntimeError('Pareto solutions not found.')
    root_path.mkdir(parents=True, exist_ok=True)
    dataset_path.mkdir(parents=True, exist_ok=True)
    ckpt_path.mkdir(parents=True, exist_ok=True)

    # fix random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda_enabled and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # prepare device — deterministic mode trades speed for reproducibility
    if cuda_enabled and torch.cuda.is_available():
        import torch.backends.cudnn as cudnn
        device = torch.device('cuda')
        if cuda_deterministic:
            cudnn.benchmark = False
            cudnn.deterministic = True
        else:
            cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # prepare dataset
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    trainset = MultiMNIST(dataset_path, train=True, download=True,
                          transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testset = MultiMNIST(dataset_path, train=False, download=True,
                         transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    # prepare network
    network = MultiLeNet()
    network.to(device)

    # initialize network from the saved Pareto solution
    start_ckpt = torch.load(start_path, map_location='cpu')
    network.load_state_dict(start_ckpt['state_dict'])

    # prepare losses — one closure per task (left digit / right digit)
    criterion = F.cross_entropy
    closures = [
        lambda n, l, t: criterion(l[0], t[:, 0]),
        lambda n, l, t: criterion(l[1], t[:, 1])
    ]

    # prepare HVP solver — full-batch gradients, per-batch Hessian products
    hvp_solver = VisionHVPSolver(network, device, trainloader, closures,
                                 shared=shared)
    hvp_solver.set_grad(batch=False)
    hvp_solver.set_hess(batch=True)

    # prepare KKT solver (MINRES on the KKT system)
    kkt_solver = MINRESKKTSolver(network, hvp_solver, device,
                                 stochastic=stochastic, kkt_momentum=kkt_momentum,
                                 create_graph=create_graph,
                                 grad_correction=grad_correction, shift=shift,
                                 tol=tol, damping=damping, maxiter=maxiter)

    # prepare optimizer
    optimizer = SGD(network.parameters(), lr=lr, momentum=momentum,
                    weight_decay=weight_decay)

    # first evaluation (baseline before any continuation step)
    losses, tops = evaluate(network, testloader, device, closures, f'{ckpt_name}')

    # prepare utilities
    top_trace = TopTrace(len(closures))
    top_trace.print(tops, show=False)

    beta = beta.to(device)

    # training — gradients come from the KKT solver, not loss.backward()
    for step in range(1, num_steps + 1):
        network.train(True)
        optimizer.zero_grad()
        kkt_solver.backward(beta, verbose=verbose)
        optimizer.step()
        losses, tops = evaluate(network, testloader, device, closures,
                                f'{ckpt_name}: {step}/{num_steps}')
        top_trace.print(tops)
        # One checkpoint per step, including the evaluation record.
        ckpt = {
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
            'beta': beta,
        }
        record = {'losses': losses, 'tops': tops}
        ckpt['record'] = record
        torch.save(ckpt, ckpt_path / f'{step:d}.pth')

    hvp_solver.close()
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.

    NOTE(review): this class uses legacy PyTorch APIs — `Variable`,
    `volatile=True`, and `loss.data[0]` — which only run on PyTorch < 0.4/0.5;
    on modern PyTorch `loss.data[0]` raises and should be `loss.item()`.
    Confirm the pinned torch version before touching this code.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator; a (train_loader, valid_loader) pair when
          config.is_train, otherwise a single test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params — loaders are assumed to use SubsetRandomSampler-style
        # samplers exposing `.indices` (presumably; verify against the caller).
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.patience = config.patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        self.optimizer = SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                           patience=self.patience)

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor
        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)
        # Initial glimpse location sampled uniformly in [-1, 1]^2.
        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)
        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement — early stop after `patience` stale epochs
            if not is_best:
                self.counter += 1
            if self.counter > self.patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # Plot glimpses for the first batch every `plot_freq` epochs.
                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration — also emits classification log-probs
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward — 1 per glimpse if final prediction correct
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss — baseline-subtracted reward,
                # detached so REINFORCE doesn't backprop into the baseline
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                # NOTE(review): .data[0] is pre-0.5 PyTorch; .item() on modern.
                losses.update(loss.data[0], x.size()[0])
                accs.update(acc.data[0], x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.data[0], acc.data[0])))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate 10 times — M Monte Carlo samples per image
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average over the M Monte Carlo samples
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.mean(-log_pi * adjusted_reward)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store (legacy .data[0] indexing — see class note)
            losses.update(loss.data[0], x.size()[0])
            accs.update(acc.data[0], x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value('valid_loss', losses.avg, iteration)
                log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x, y) in enumerate(self.test_loader):
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            # volatile=True is legacy no-grad inference (see class note)
            x, y = Variable(x, volatile=True), Variable(y)

            # duplicate 10 times — M Monte Carlo samples per image
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        print('[*] Test Acc: {}/{} ({:.2f}%)'.format(correct, self.num_test,
                                                     perc))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a seperate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'] + 1, ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch'] + 1))
loss.backward() optimizer.step() trian_loss += loss.item() / len(train_loader) _, y_pred = torch.max(y_pred, 1) trian_acc += (y_pred == batch_y).sum().item() / len(y) train_loss_temp.append(trian_loss) train_acc_temp.append(trian_acc) if epoch % checkpoint == 0: torch.save(model.state_dict(), f"./data/prob{prob_num}_model_ckpt_{epoch}.bin") torch.save(optimizer.state_dict(), f"./data/prob{prob_num}_optimizer_ckpt_{epoch}.bin") model = model.eval() with torch.no_grad(): y_pred = model(X_valid) loss = loss_fn(y_pred, y_valid) _, y_pred = torch.max(y_pred, 1) acc = (y_pred == y_valid).sum().item() / len(y_valid) valid_loss_temp.append(loss) valid_acc_temp.append(acc) train_loss_all.append(train_loss_temp) train_acc_all.append(train_acc_temp) valid_loss_all.append(valid_loss_temp)
class Trainer:
    """Trains a saliency-prediction model on the SALICON train/val splits.

    Logs per-step metrics to TensorBoard, validates every
    ``val_frequency`` epochs, and saves a pickle checkpoint every 10 epochs.
    """

    def __init__(self,
                 model: nn.Module,
                 dataset_root: str,
                 summary_writer: SummaryWriter,
                 device: Device,
                 batch_size: int = 128,
                 cc_loss: bool = False):
        # load train/test splits of SALICON dataset
        train_dataset = Salicon(dataset_root + "train.pkl")
        test_dataset = Salicon(dataset_root + "val.pkl")
        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=1,
        )
        self.val_loader = DataLoader(
            test_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
            pin_memory=True,
        )
        self.model = model.to(device)
        self.device = device
        # NOTE(review): CCLoss is assigned without parentheses while
        # nn.MSELoss() is instantiated — confirm CCLoss is a plain callable
        # (function) rather than a class that must be constructed.
        if cc_loss:
            self.criterion = CCLoss
        else:
            self.criterion = nn.MSELoss()
        self.optimizer = SGD(self.model.parameters(),
                             lr=0.03,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
        self.summary_writer = summary_writer
        # Global step counter, shared across epochs (used for TB x-axis).
        self.step = 0

    def train(self,
              epochs: int,
              val_frequency: int,
              log_frequency: int = 5,
              start_epoch: int = 0):
        """Run the training loop for ``epochs`` epochs from ``start_epoch``."""
        # Linear LR schedule from 0.03 down to 0.0001 (per the paper).
        lrs = np.linspace(0.03, 0.0001, epochs)
        for epoch in range(start_epoch, epochs):
            self.model.train()
            for batch, gts in self.train_loader:
                # LR decay
                # need to update learning rate between 0.03 and 0.0001 (according to paper)
                # NOTE(review): a fresh SGD is built with lrs[epoch] and then
                # the previous optimizer's state dict is loaded into it.
                # Optimizer.load_state_dict also restores param_group
                # hyperparameters, which would overwrite the new lr —
                # verify the schedule actually takes effect.
                optimstate = self.optimizer.state_dict()
                self.optimizer = SGD(self.model.parameters(),
                                     lr=lrs[epoch],
                                     momentum=0.9,
                                     weight_decay=0.0005,
                                     nesterov=True)
                self.optimizer.load_state_dict(optimstate)
                self.optimizer.zero_grad()
                # load batch to device
                batch = batch.to(self.device)
                gts = gts.to(self.device)
                # train step
                step_start_time = time.time()
                output = self.model.forward(batch)
                loss = self.criterion(output, gts)
                loss.backward()
                self.optimizer.step()
                # log step
                if ((self.step + 1) % log_frequency) == 0:
                    step_time = time.time() - step_start_time
                    self.log_metrics(epoch, loss, step_time)
                    self.print_metrics(epoch, loss, step_time)
                # count steps
                self.step += 1
            # log epoch
            self.summary_writer.add_scalar("epoch", epoch, self.step)
            # validate
            if ((epoch + 1) % val_frequency) == 0:
                self.validate()
                # validate() switches to eval mode; switch back for training.
                self.model.train()
            if (epoch + 1) % 10 == 0:
                save(self.model, "checkp_model.pkl")

    def print_metrics(self, epoch, loss, step_time):
        """Print a one-line progress summary for the current step."""
        epoch_step = self.step % len(self.train_loader)
        print(f"epoch: [{epoch}], "
              f"step: [{epoch_step}/{len(self.train_loader)}], "
              f"batch loss: {loss:.5f}, "
              f"step time: {step_time:.5f}")

    def log_metrics(self, epoch, loss, step_time):
        """Write epoch index, train loss, and step time to TensorBoard."""
        self.summary_writer.add_scalar("epoch", epoch, self.step)
        self.summary_writer.add_scalars("loss", {"train": float(loss.item())},
                                        self.step)
        self.summary_writer.add_scalar("time/data", step_time, self.step)

    def validate(self):
        """Evaluate on the validation split and log the mean loss."""
        results = {"preds": [], "gts": []}
        total_loss = 0
        self.model.eval()
        # No need to track gradients for validation, we're not optimizing.
        with no_grad():
            for batch, gts in self.val_loader:
                batch = batch.to(self.device)
                gts = gts.to(self.device)
                output = self.model(batch)
                loss = self.criterion(output, gts)
                total_loss += loss.item()
                preds = output.cpu().numpy()
                results["preds"].extend(list(preds))
                results["gts"].extend(list(gts.cpu().numpy()))
        average_loss = total_loss / len(self.val_loader)
        self.summary_writer.add_scalars("loss", {"test": average_loss},
                                        self.step)
        print(f"validation loss: {average_loss:.5f}")
def main(args: argparse.Namespace):
    """Regression domain adaptation for pose estimation (RegDA).

    Pipeline: build source/target loaders -> build RegDAPoseResNet ->
    (optionally) pretrain on source or resume from a checkpoint ->
    adversarial adaptation loop with per-epoch validation and
    best-checkpoint tracking. ``args.phase == 'test'`` only evaluates.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # Deterministic mode trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # NOTE(review): T.GaussianBlur() is called with no arguments — fine if T
    # is a project transform module with defaults; torchvision's GaussianBlur
    # requires kernel_size. Confirm which T is imported.
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    # Source-domain loaders (train + test split).
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)
    # Target-domain loaders (train + test split).
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)
    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))
    # Infinite iterators so train() can draw a fixed number of steps/epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # Separate optimizers: feature extractor (f) at lr 0.1, and the two
    # regressor heads (h, h_adv) at lr 1.0, all decayed by the same lambda.
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    # Inverse-decay schedule: lr * (1 + gamma*x)^(-decay).
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step,
                                       args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # NOTE(review): scheduler.step() before the optimizer step is
                # the pre-1.1 PyTorch ordering — confirm intended with the
                # pinned torch version.
                lr_scheduler.step()
                print(lr_scheduler.get_lr())
                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)
                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))
        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(),
              lr_scheduler_h_adv.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)
        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            # Copy (not move) the per-epoch checkpoint to the 'best' slot.
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
    logger.close()
class Trainer(object):
    """Multi-label image classification trainer.

    Builds loaders/model/optimizer from ``args``, trains with early stopping
    on validation mAP, and checkpoints both the latest and the best model.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        # Train-time augmentation: random-sized crop from a fixed scale,
        # then resize back to the model's crop size.
        train_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.RandomChoice([
                transforms.RandomCrop(640),
                transforms.RandomCrop(576),
                transforms.RandomCrop(512),
                transforms.RandomCrop(384),
                transforms.RandomCrop(320)
            ]),
            transforms.Resize((args.crop_size, args.crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        train_dataset = MLDataset(args.train_path, args.label_path,
                                  train_transform)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        # Deterministic center-crop pipeline for validation.
        val_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        val_dataset = MLDataset(args.val_path, args.label_path, val_transform)
        self.val_loader = DataLoader(dataset=val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)
        self.model = model_factory[args.model](args, args.num_classes)
        self.model.cuda()
        # Only optimize parameters that are not frozen.
        trainable_parameters = filter(lambda param: param.requires_grad,
                                      self.model.parameters())
        if args.optimizer == 'Adam':
            self.optimizer = Adam(trainable_parameters, lr=args.lr)
        elif args.optimizer == 'SGD':
            self.optimizer = SGD(trainable_parameters, lr=args.lr)
        # mode='max': the scheduler monitors validation mAP (higher = better).
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                           mode='max',
                                                           patience=2,
                                                           verbose=True)
        if args.loss == 'BCElogitloss':
            self.criterion = nn.BCEWithLogitsLoss()
        elif args.loss == 'tencentloss':
            self.criterion = TencentLoss(args.num_classes)
        elif args.loss == 'focalloss':
            self.criterion = FocalLoss()
        self.early_stopping = EarlyStopping(patience=5)
        # Metric accumulators (reset per validation pass).
        self.voc12_mAP = VOC12mAP(args.num_classes)
        self.average_loss = AverageLoss()
        self.average_topk_meter = TopkAverageMeter(args.num_classes,
                                                   topk=args.topk)
        self.average_threshold_meter = ThresholdAverageMeter(
            args.num_classes, threshold=args.threshold)
        self.args = args
        self.global_step = 0
        self.writer = SummaryWriter(log_dir=args.log_dir)

    def run(self):
        """Train for up to ``max_epoch`` epochs with resume + early stopping."""
        s_epoch = 0
        if self.args.resume:
            # Resume everything (model, optimizer, step counter, best score)
            # from the latest checkpoint.
            checkpoint = torch.load(self.args.ckpt_latest_path)
            s_epoch = checkpoint['epoch']
            self.global_step = checkpoint['global_step']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
            self.early_stopping.best_score = checkpoint['best_score']
            print('loading checkpoint success (epoch {})'.format(s_epoch))
        for epoch in range(s_epoch, self.args.max_epoch):
            self.train(epoch)
            save_dict = {
                'epoch': epoch + 1,
                'global_step': self.global_step,
                'model_state_dict': self.model.state_dict(),
                'optim_state_dict': self.optimizer.state_dict(),
                'best_score': self.early_stopping.best_score
            }
            torch.save(save_dict, self.args.ckpt_latest_path)
            mAP = self.validation(epoch)
            self.lr_scheduler.step(mAP)
            is_save, is_terminate = self.early_stopping(mAP)
            if is_terminate:
                break
            if is_save:
                torch.save(self.model.state_dict(), self.args.ckpt_best_path)

    def train(self, epoch):
        """Run one training epoch; logs loss every 400 global steps."""
        self.model.train()
        if self.args.model == 'ssgrl':
            # Freeze BN statistics of the backbone except its last stage.
            self.model.resnet_101.eval()
            self.model.resnet_101.layer4.train()
        for _, batch in enumerate(self.train_loader):
            x, y = batch[0].cuda(), batch[1].cuda()
            pred_y = self.model(x)
            loss = self.criterion(pred_y, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.global_step % 400 == 0:
                self.writer.add_scalar('Loss/train', loss, self.global_step)
                print('TRAIN [epoch {}] loss: {:4f}'.format(epoch, loss))
            self.global_step += 1

    def validation(self, epoch):
        """Evaluate on the validation split; returns mAP for scheduling."""
        self.model.eval()
        self.voc12_mAP.reset()
        self.average_loss.reset()
        self.average_topk_meter.reset()
        self.average_threshold_meter.reset()
        with torch.no_grad():
            for _, batch in enumerate(self.val_loader):
                x, y = batch[0].cuda(), batch[1].cuda()
                pred_y = self.model(x)
                loss = self.criterion(pred_y, y)
                y = y.cpu().numpy()
                pred_y = pred_y.cpu().numpy()
                loss = loss.cpu().numpy()
                self.voc12_mAP.update(pred_y, y)
                self.average_loss.update(loss, x.size(0))
                self.average_topk_meter.update(pred_y, y)
                self.average_threshold_meter.update(pred_y, y)
        _, mAP = self.voc12_mAP.compute()
        mLoss = self.average_loss.compute()
        self.average_topk_meter.compute()
        self.average_threshold_meter.compute()
        self.writer.add_scalar('Loss/val', mLoss, self.global_step)
        self.writer.add_scalar('mAP/val', mAP, self.global_step)
        print("Validation [epoch {}] mAP: {:.4f} loss: {:.4f}".format(
            epoch, mAP, mLoss))
        return mAP
# --- Fragment of a triplet-loss training loop (loop header and the
# accumulators epoch/batch_idx/output/triplet_loss_sum are above this chunk).
optimizer.step()
# Debug output: peak CUDA memory and the current loss value.
print(torch.cuda.max_memory_allocated(device=0))
print(output)
# Append per-batch metrics as a tab-separated line in the text log.
with open('./Logs/log_triplet_new.txt', 'a') as f:
    val_list = [
        epoch + 1, batch_idx,
        float(output),
        float(triplet_loss_sum)
    ]
    log = '\t'.join(str(value) for value in val_list)
    f.writelines(log + '\n')
# End-of-epoch summary line ('FINAL') in the same log.
avg_triplet_loss = triplet_loss_sum / batches_per_epoch
with open('./Logs/log_triplet_new.txt', 'a') as f:
    val_list = ['FINAL', epoch + 1, float(avg_triplet_loss)]
    log = '\t'.join(str(value) for value in val_list)
    f.writelines(log + '\n')
print('Epoch {}:\tAverage Triplet Loss: {:.4f}\t'.format(
    epoch + 1, avg_triplet_loss))
# Checkpoint named after the epoch and the rounded average loss.
torch.save(
    {
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'avg_triplet_loss': avg_triplet_loss
    }, './Train_Checkpoints/' + 'checkpoint_' + str(epoch) + '_' +
    str(round(float(avg_triplet_loss), 4)) + '.tar')
class MetaFrameWork(object):
    """Meta-learning framework for domain-generalized semantic segmentation.

    Alternates plain source-domain steps (``train``) with MAML-style
    meta-train/meta-test steps (``meta_train``), tracks the best
    target-domain and source-domain checkpoints, and supports resume.
    """

    def __init__(self, name='normal_all', train_num=1, source='GSIM',
                 target='C', network=Net, resume=True, dataset=DGMetaDataSets,
                 inner_lr=1e-3, outer_lr=5e-3, train_size=8, test_size=16,
                 no_source_test=True, bn='torch'):
        super(MetaFrameWork, self).__init__()
        self.no_source_test = no_source_test
        # Every train_num-th iteration runs a meta step (see do_train).
        self.train_num = train_num
        self.exp_name = name
        self.resume = resume
        self.inner_update_lr = inner_lr
        self.outer_update_lr = outer_lr
        self.network = network
        self.dataset = dataset
        self.train_size = train_size
        self.test_size = test_size
        self.source = source
        self.target = target
        self.bn = bn

        # Epoch counter and running "best" bookkeeping.
        self.epoch = 1
        self.best_target_acc = 0
        self.best_target_acc_source = 0
        self.best_target_epoch = 1

        self.best_source_acc = 0
        self.best_source_acc_target = 0
        self.best_source_epoch = 0

        self.total_epoch = 120
        self.save_interval = 1
        self.save_path = Path(self.exp_name)
        self.init()

    def init(self):
        """Build networks, loaders, optimizer, scheduler, and logger."""
        kwargs = {'bn': self.bn, 'output_stride': 8}
        self.backbone = nn.DataParallel(self.network(**kwargs)).cuda()
        kwargs.update({'pretrained': False})
        # Second copy of the network that receives the inner-loop update.
        self.updated_net = nn.DataParallel(self.network(**kwargs)).cuda()
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)
        self.nim = NaturalImageMeasure(nclass=19)

        batch_size = self.train_size
        workers = len(self.source) * 4

        dataloader = functools.partial(DataLoader, num_workers=workers,
                                       pin_memory=True, batch_size=batch_size,
                                       shuffle=True)
        self.train_loader = dataloader(self.dataset(mode='train',
                                                    domains=self.source,
                                                    force_cache=True))

        dataloader = functools.partial(DataLoader, num_workers=workers,
                                       pin_memory=True,
                                       batch_size=self.test_size,
                                       shuffle=False)
        self.source_val_loader = dataloader(self.dataset(mode='val',
                                                         domains=self.source,
                                                         force_cache=True))

        target_dataset, folder = get_dataset(self.target)
        self.target_loader = dataloader(target_dataset(root=ROOT + folder,
                                                       mode='val'))
        self.target_test_loader = dataloader(target_dataset(
            root=ROOT + 'cityscapes', mode='test'))

        self.opt_old = SGD(self.backbone.parameters(),
                           lr=self.outer_update_lr, momentum=0.9,
                           weight_decay=5e-4)
        self.scheduler_old = PolyLR(self.opt_old, self.total_epoch,
                                    len(self.train_loader), 0, True,
                                    power=0.9)
        self.logger = get_logger('train', self.exp_name)
        self.log('exp_name : {}, train_num = {}, source domains = {}, target_domain = {}, lr : inner = {}, outer = {},'
                 'dataset : {}, net : {}, bn : {}\n'.
                 format(self.exp_name, self.train_num, self.source, self.target, self.inner_update_lr, self.outer_update_lr,
                        self.dataset, self.network, self.bn))
        self.log(self.exp_name + '\n')
        self.train_timer, self.test_timer = Timer(), Timer()

    def train(self, epoch, it, inputs):
        """One plain (non-meta) supervised step on all source domains.

        Returns (losses, acc, lr) for logging.
        """
        # imgs : batch x domains x C x H x W
        # targets : batch x domains x 1 x H x W
        imgs, targets = inputs
        B, D, C, H, W = imgs.size()
        meta_train_imgs = imgs.view(-1, C, H, W)
        meta_train_targets = targets.view(-1, 1, H, W)

        tr_logits = self.backbone(meta_train_imgs)[0]
        tr_logits = make_same_size(tr_logits, meta_train_targets)
        ds_loss = self.ce(tr_logits, meta_train_targets[:, 0])
        with torch.no_grad():
            self.nim(tr_logits, meta_train_targets)

        self.opt_old.zero_grad()
        ds_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)
        losses = {
            'dg': 0,
            'ds': ds_loss.item()
        }
        acc = {
            'iou': self.nim.get_res()[0]
        }
        return losses, acc, self.scheduler_old.get_lr(epoch, it)[0]

    def meta_train(self, epoch, it, inputs):
        """One MAML-style step: random domain split into meta-train/meta-test.

        The inner update (ds_loss) produces updated_net; the outer update
        applies the meta-test gradient (dg_loss) to the backbone.
        """
        # imgs : batch x domains x C x H x W
        # targets : batch x domains x 1 x H x W
        imgs, targets = inputs
        B, D, C, H, W = imgs.size()
        split_idx = np.random.permutation(D)
        i = np.random.randint(1, D)
        train_idx = split_idx[:i]
        test_idx = split_idx[i:]

        # train_idx = split_idx[:D // 2]
        # test_idx = split_idx[D // 2:]
        # self.print(split_idx, B, D, C, H, W)'
        meta_train_imgs = imgs[:, train_idx].reshape(-1, C, H, W)
        meta_train_targets = targets[:, train_idx].reshape(-1, 1, H, W)
        meta_test_imgs = imgs[:, test_idx].reshape(-1, C, H, W)
        meta_test_targets = targets[:, test_idx].reshape(-1, 1, H, W)

        # Meta-Train
        tr_logits = self.backbone(meta_train_imgs)[0]
        tr_logits = make_same_size(tr_logits, meta_train_targets)
        ds_loss = self.ce(tr_logits, meta_train_targets[:, 0])

        # Update new network
        self.opt_old.zero_grad()
        # retain_graph so the outer (dg) backward can reuse this graph.
        ds_loss.backward(retain_graph=True)
        updated_net = get_updated_network(self.backbone, self.updated_net,
                                          self.inner_update_lr).train().cuda()

        # Meta-Test
        te_logits = updated_net(meta_test_imgs)[0]
        # te_logits = test_res[0]
        te_logits = make_same_size(te_logits, meta_test_targets)
        dg_loss = self.ce(te_logits, meta_test_targets[:, 0])
        with torch.no_grad():
            self.nim(te_logits, meta_test_targets)

        # Update old network
        dg_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)
        losses = {
            'dg': dg_loss.item(),
            'ds': ds_loss.item()
        }
        acc = {
            'iou': self.nim.get_res()[0],
        }
        return losses, acc, self.scheduler_old.get_lr(epoch, it)[0]

    def do_train(self):
        """Full training loop: resume, per-iteration (meta-)steps, periodic
        validation on the target domain, and checkpointing."""
        if self.resume:
            self.load()
        self.writer = SummaryWriter(str(self.save_path / 'tensorboard'),
                                    filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S'))
        self.log('Start epoch : {}\n'.format(self.epoch))
        for epoch in range(self.epoch, self.total_epoch + 1):
            loss_meters, acc_meters = MeterDicts(), MeterDicts(averaged=['iou'])
            self.nim.clear_cache()
            self.backbone.train()
            self.epoch = epoch
            with self.train_timer:
                for it, (paths, imgs, target) in enumerate(self.train_loader):
                    # Every train_num-th iteration is a meta step.
                    meta = (it + 1) % self.train_num == 0
                    if meta:
                        losses, acc, lr = self.meta_train(epoch - 1, it, to_cuda([imgs, target]))
                    else:
                        losses, acc, lr = self.train(epoch - 1, it, to_cuda([imgs, target]))

                    loss_meters.update_meters(losses, skips=['dg'] if not meta else [])
                    acc_meters.update_meters(acc)

                    self.print(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta), end='')
                    self.tfb_log(epoch, it, loss_meters, acc_meters)
            self.print(self.train_timer.get_formatted_duration())
            self.log(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta) + '\n')

            self.save('ckpt')
            if epoch % self.save_interval == 0:
                with self.test_timer:
                    city_acc = self.val(self.target_loader)
                    self.save_best(city_acc, epoch)
            total_duration = self.train_timer.duration + self.test_timer.duration
            self.print('Time Left : ' + self.train_timer.get_formatted_duration(
                total_duration * (self.total_epoch - epoch)) + '\n')
        self.log('Best city acc : \n city : {}, origin : {}, epoch : {}\n'.format(
            self.best_target_acc, self.best_target_acc_source, self.best_target_epoch))
        self.log('Best origin acc : \n city : {}, origin : {}, epoch : {}\n'.format(
            self.best_source_acc_target, self.best_source_acc, self.best_source_epoch))

    def save_best(self, city_acc, epoch):
        """Record target (and optionally source) accuracy; keep best ckpts."""
        self.writer.add_scalar('acc/citys', city_acc, epoch)
        if not self.no_source_test:
            origin_acc = self.val(self.source_val_loader)
            self.writer.add_scalar('acc/origin', origin_acc, epoch)
        else:
            origin_acc = 0
        self.writer.flush()
        if city_acc > self.best_target_acc:
            self.best_target_acc = city_acc
            self.best_target_acc_source = origin_acc
            self.best_target_epoch = epoch
            self.save('best_city')

        if origin_acc > self.best_source_acc:
            self.best_source_acc = origin_acc
            self.best_source_acc_target = city_acc
            self.best_source_epoch = epoch
            self.save('best_origin')

    def val(self, dataset):
        """Standard eval-mode validation; returns the headline accuracy."""
        self.backbone.eval()
        with torch.no_grad():
            self.nim.clear_cache()
            self.nim.set_max_len(len(dataset))
            for p, img, target in dataset:
                img, target = to_cuda(get_img_target(img, target))
                logits = self.backbone(img)[0]
                self.nim(logits, target)
        self.log('\nNormal validation : {}\n'.format(self.nim.get_acc()))
        if hasattr(dataset.dataset, 'format_class_iou'):
            self.log(dataset.dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def target_specific_val(self, loader):
        """Validation with dropout removed and BN in train mode (no stats
        tracking), evaluating each domain slice separately."""
        self.nim.clear_cache()
        self.nim.set_max_len(len(loader))
        # eval for dropout
        self.backbone.module.remove_dropout()
        self.backbone.module.not_track()
        for idx, (p, img, target) in enumerate(loader):
            if len(img.size()) == 5:
                B, D, C, H, W = img.size()
            else:
                B, C, H, W = img.size()
                D = 1
            img, target = to_cuda([img.reshape(B, D, C, H, W), target.reshape(B, D, 1, H, W)])
            for d in range(img.size(1)):
                img_d, target_d, = img[:, d], target[:, d]
                # NOTE(review): train() here keeps BN in train mode while
                # gradients are disabled — relies on not_track() above.
                self.backbone.train()
                with torch.no_grad():
                    new_logits = self.backbone(img_d)[0]
                    self.nim(new_logits, target_d)
        self.backbone.module.recover_dropout()
        self.log('\nTarget specific validation : {}\n'.format(self.nim.get_acc()))
        if hasattr(loader.dataset, 'format_class_iou'):
            self.log(loader.dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def predict_target(self, load_path='best_city', color=False, train=False,
                       output_path='predictions'):
        """Write per-image predictions (raw or colorized) for the target
        test set using the checkpoint named by ``load_path``."""
        self.load(load_path)
        import skimage.io as skio
        dataset = self.target_test_loader

        output_path = Path(self.save_path / output_path)
        output_path.mkdir(exist_ok=True)
        if train:
            self.backbone.module.remove_dropout()
            self.backbone.train()
        else:
            self.backbone.eval()
        with torch.no_grad():
            self.nim.clear_cache()
            self.nim.set_max_len(len(dataset))
            for names, img, target in tqdm(dataset):
                img = to_cuda(img)
                logits = self.backbone(img)[0]
                logits = F.interpolate(logits, img.size()[2:], mode='bilinear',
                                       align_corners=True)
                preds = get_prediction(logits).cpu().numpy()
                if color:
                    trainId_preds = preds
                else:
                    trainId_preds = dataset.dataset.predict(preds)

                for pred, name in zip(trainId_preds, names):
                    file_name = name.split('/')[-1]
                    if color:
                        pred = class_map_2_color_map(pred).transpose(1, 2, 0).astype(np.uint8)
                    skio.imsave(str(output_path / file_name), pred)

    def get_string(self, epoch, it, loss_meters, acc_meters, lr, meta):
        """Format one progress line (carriage-return style) for printing."""
        string = '\repoch {:4}, iter : {:4}, '.format(epoch, it)
        for k, v in loss_meters.items():
            string += k + ' : {:.4f}, '.format(v.avg)
        for k, v in acc_meters.items():
            string += k + ' : {:.4f}, '.format(v.avg)

        string += 'lr : {:.6f}, meta : {}'.format(lr, meta)
        return string

    def log(self, strs):
        # File/console logging via the configured logger.
        self.logger.info(strs)

    def print(self, strs, **kwargs):
        # Plain stdout printing (kept separate from log()).
        print(strs, **kwargs)

    def tfb_log(self, epoch, it, losses, acc):
        """Write current loss/accuracy meter values to TensorBoard."""
        iteration = epoch * len(self.train_loader) + it
        # Losses
        for k, v in losses.items():
            self.writer.add_scalar('loss/' + k, v.val, iteration)
        for k, v in acc.items():
            self.writer.add_scalar('acc/' + k, v.val, iteration)

    def save(self, name='ckpt'):
        """Save backbone/optimizer state plus best-score bookkeeping."""
        info = [self.best_source_acc, self.best_source_acc_target,
                self.best_source_epoch, self.best_target_acc,
                self.best_target_acc_source, self.best_target_epoch]
        dicts = {
            'backbone': self.backbone.state_dict(),
            'opt': self.opt_old.state_dict(),
            'epoch': self.epoch + 1,
            'best': self.best_target_acc,
            'info': info
        }
        self.print('Saving epoch : {}'.format(self.epoch))
        torch.save(dicts, self.save_path / '{}.pth'.format(name))

    def load(self, path=None, strict=False):
        """Best-effort restore; returns True on success, False otherwise.

        Accepts a bare name (resolved under save_path) or a full .pth path.
        """
        if path is None:
            path = self.save_path / 'ckpt.pth'
        else:
            if 'pth' in path:
                path = path
            else:
                path = self.save_path / '{}.pth'.format(path)

        try:
            dicts = torch.load(path, map_location='cpu')
            msg = self.backbone.load_state_dict(dicts['backbone'], strict=strict)
            self.print(msg)
            if 'opt' in dicts:
                self.opt_old.load_state_dict(dicts['opt'])
            if 'epoch' in dicts:
                self.epoch = dicts['epoch']
            else:
                self.epoch = 1
            if 'best' in dicts:
                self.best_target_acc = dicts['best']
            if 'info' in dicts:
                self.best_source_acc, self.best_source_acc_target, self.best_source_epoch, \
                self.best_target_acc, self.best_target_acc_source, self.best_target_epoch = dicts['info']
            self.log('Loaded from {}, next epoch : {}, best_target : {}, best_epoch : {}\n'
                     .format(str(path), self.epoch, self.best_target_acc, self.best_target_epoch))
            return True
        except Exception as e:
            # Any failure (missing file, schema mismatch) restarts from epoch 1.
            self.print(e)
            self.log('No ckpt found in {}\n'.format(str(path)))
            self.epoch = 1
            return False
def main(args):
    """ Main function

        Here, you should instantiate
        1) Dataset objects for training and test datasets
        2) DataLoaders for training and testing
        3) model
        4) optimizer: SGD with initial learning rate 0.01 and momentum 0.9
        5) cost function: use torch.nn.CrossEntropyLoss

    Modes:
        train          - train (optionally resuming from the last checkpoint)
        test           - evaluate every saved checkpoint on the test set
        graph          - plot train/test curves for one model
        graph_compare  - plot test curves across several saved models

    Bug fixes vs. previous revision:
        * optimizer.load_state_dic -> load_state_dict (AttributeError on resume)
        * checkpoints were saved under the key 'opimizer' but read back as
          'optimizer'; now saved as 'optimizer', with a legacy fallback on load.
    """
    # write your codes here
    # Configuration
    mode = args.mode
    model_name = args.model
    options = args.o
    if mode == 'train':
        train_data_dir = args.d + '/train/'
    elif mode == 'test':
        test_data_dir = args.d + '/test/'
    elif mode == 'graph_compare':
        # Pre-defined model-name sets to compare against each other.
        if model_name == 'LeNet5':
            models_name = [
                'LeNet5', 'LeNet5_insert_noise_s0.1_m0.0',
                'LeNet5_insert_noise_s0.2_m0.0',
                'LeNet5_insert_noise_s0.3_m0.0', 'LeNet5_weight_decay_0.0001',
                'LeNet5_weight_decay_0.001', 'LeNet5_weight_decay_0.01'
            ]
        elif model_name == 'CustomMLP_6':
            models_name = [
                'CustomMLP_6', 'CustomMLP_6_weight_decay_1e-05',
                'CustomMLP_6_weight_decay_0.0001',
                'CustomMLP_6_weight_decay_0.001'
            ]
        else:
            models_name = [
                'LeNet5', 'CustomMLP_1', 'CustomMLP_2', 'CustomMLP_3',
                'CustomMLP_4', 'CustomMLP_5', 'CustomMLP_6'
            ]
    model_path = args.m
    device = torch.device("cuda:" + str(args.cuda))
    # NOTE(review): docstring says momentum 0.9 but 0.6 is used — confirm.
    lr = 0.01
    momentum = 0.6
    batch_size = args.b
    epoch = args.e
    use_ckpt = args.c

    # Hidden-layer widths for each CustomMLP variant.
    if model_name == "CustomMLP_1":
        layer_option = [54, 47, 35, 10, 39]
    elif model_name == "CustomMLP_2":
        layer_option = [55, 35, 30, 34]
    elif model_name == "CustomMLP_3":
        layer_option = [55, 34, 33, 31]
    elif model_name == "CustomMLP_4":
        layer_option = [55, 41, 41]
    elif model_name == "CustomMLP_5":
        layer_option = [56, 51]
    elif model_name == "CustomMLP_6":
        layer_option = [58]

    ##change models
    if mode != "graph_compare":
        if model_name.split('_')[0] == "LeNet5":
            model = LeNet5(device).to(device)
        elif model_name.split('_')[0] == "CustomMLP":
            model = CustomMLP(layer_option).to(device)

    ##change model name
    # Encode the regularization option into the model/checkpoint name.
    if options:
        model_name = model_name + '_' + options
        if options == "weight_decay":
            weight_decay = args.w
            gausian_noise_mean = 0.
            gausian_noise_std = 0.
            model_name += '_' + str(weight_decay)
        elif options == "insert_noise":
            weight_decay = 0.
            gausian_noise_mean = args.mean
            gausian_noise_std = args.std
            model_name += '_s' + str(gausian_noise_std) + "_m" + str(
                gausian_noise_mean)
    else:
        weight_decay = 0.

    ##change criterion
    criterion = CrossEntropyLoss()
    #Custom TimeModule
    mytime = CheckTime()

    if mode == "train":
        # Load Dataset
        print(
            "{} Start Loading Train Dataset ==================================="
            .format(mytime.get_running_time_str()))
        train_dataset = dataset.MNIST(train_data_dir, gausian_noise_mean,
                                      gausian_noise_std)
        train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)

        # initiate optimizer
        optimizer = SGD(model.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay)

        # If use checkpoint ...
        if use_ckpt:
            ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
            ckpt_files.sort()
            ckpt_model_path = ckpt_files[-1]
            epoch_info = torch.load(ckpt_model_path, map_location=device)
            start_epoch = epoch_info['epoch'] - 1
            model.load_state_dict(epoch_info['model'])
            # FIX: was optimizer.load_state_dic(...) (AttributeError) reading
            # a key ('optimizer') that was never written. Load the correctly
            # named method, and fall back to the legacy misspelled key so old
            # checkpoints remain resumable.
            optim_state = epoch_info.get('optimizer',
                                         epoch_info.get('opimizer'))
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            total_trn_loss = epoch_info['total_trn_loss']
            total_trn_acc = epoch_info['total_trn_acc']
        else:
            start_epoch = 0
            total_trn_loss = []
            total_trn_acc = []

        # Check Random Parameter Model Loss & Accuracy
        print(
            "{} Check Random Parameter Model {}========================================= "
            .format(mytime.get_running_time_str(), model_name))
        with torch.no_grad():
            trn_loss, trn_acc = test(model, train_dataloader, device,
                                     criterion)
        total_trn_loss.append(trn_loss.item())
        total_trn_acc.append(trn_acc.item())
        i = 0
        # Save the epoch-0 (untrained) snapshot.
        torch.save(
            {
                'epoch': i,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),  # FIX: key was 'opimizer'
                'total_trn_loss': total_trn_loss,
                'total_trn_acc': total_trn_acc
            }, model_path + '{}_model_{:04d}.pt'.format(model_name, i))
        print("{} train {} // epoch: {} // loss: {:.6f} // accuracy: {:.2f} ".
              format(mytime.get_running_time_str(), model_name, i, trn_loss,
                     trn_acc))

        # Start traing model
        print("{} Start Training {}========================================= ".
              format(mytime.get_running_time_str(), model_name))
        for i in range(start_epoch, epoch):
            trn_loss, trn_acc = train(model, train_dataloader, device,
                                      criterion, optimizer)
            total_trn_loss.append(trn_loss.item())
            total_trn_acc.append(trn_acc.item())
            # One checkpoint per epoch, carrying the full loss/acc history.
            torch.save(
                {
                    'epoch': i,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),  # FIX: was 'opimizer'
                    'total_trn_loss': total_trn_loss,
                    'total_trn_acc': total_trn_acc
                }, model_path + '{}_model_{:04d}.pt'.format(model_name, i + 1))
            print(
                "{} train {} // epoch: {} // loss: {:.6f} // accuracy: {:.2f} "
                .format(mytime.get_running_time_str(), model_name, i + 1,
                        trn_loss, trn_acc))

    if mode == "test":
        #Start Loading Test Dataset
        print(
            "{} Start Loading Test Dataset ==================================="
            .format(mytime.get_running_time_str()))
        test_dataset = dataset.MNIST(test_data_dir)
        test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

        # Start Testing model: evaluate every saved checkpoint in order and
        # write the accumulated test history back into each checkpoint file.
        with torch.no_grad():
            ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
            ckpt_files.sort()
            total_tst_loss = []
            total_tst_acc = []
            for i, ckpt_model_path in enumerate(ckpt_files):
                epoch_info = torch.load(ckpt_model_path, map_location=device)
                model.load_state_dict(epoch_info['model'])
                tst_loss, tst_acc = test(model, test_dataloader, device,
                                         criterion)
                total_tst_loss.append(tst_loss.item())
                total_tst_acc.append(tst_acc.item())
                epoch_info['total_tst_loss'] = total_tst_loss
                epoch_info['total_tst_acc'] = total_tst_acc
                torch.save(epoch_info, ckpt_model_path)
                print(
                    "{} test {} // model_num: {} // loss: {:.6f} // accuracy: {:.2f} "
                    .format(mytime.get_running_time_str(), model_name, i,
                            tst_loss, tst_acc))

    if mode == "graph":
        #Load models to draw graph
        ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
        ckpt_files.sort()
        # The last checkpoint carries the full train/test history.
        epoch_info = torch.load(ckpt_files[-1])

        #initiate loss and accuracy dictionary
        loss_dic = {}
        acc_dic = {}

        #add loss and accuracy list
        loss_dic['train'] = epoch_info['total_trn_loss']
        loss_dic['test'] = epoch_info['total_tst_loss']
        acc_dic['train'] = epoch_info['total_trn_acc']
        acc_dic['test'] = epoch_info['total_tst_acc']
        num_epoch = len(loss_dic['train'])

        #Draw Graph per model: trn_loss + tst_loss
        graph_name = "Loss (model - {}) ".format(model_name)
        draw_model_graph(graph_name,
                         num_epoch,
                         loss_dic,
                         graph_mode="loss",
                         save=args.s,
                         zoom_plot=args.z)

        #Draw Graph per model: trn_acc + tst_acc
        graph_name = "Accuracy (model - {}) ".format(model_name)
        draw_model_graph(graph_name,
                         num_epoch,
                         acc_dic,
                         graph_mode="acc",
                         save=args.s,
                         zoom_plot=args.z)

    if mode == "graph_compare":
        tst_loss_dic = {}
        tst_acc_dic = {}
        print(models_name)

        #Load pre-defined models
        for model_name in models_name:
            model_file_name = model_path + '{}_model_{:04d}.pt'.format(
                model_name, epoch)
            epoch_info = torch.load(model_file_name)
            tst_loss_dic[model_name] = epoch_info['total_tst_loss']
            tst_acc_dic[model_name] = epoch_info['total_tst_acc']
            num_epoch = len(tst_loss_dic[model_name])

        #Comparison models: tst_loss
        graph_name = "Compare Loss"
        draw_model_graph(graph_name,
                         num_epoch,
                         tst_loss_dic,
                         graph_mode="loss",
                         save=args.s,
                         zoom_plot=args.z)

        #Comparison models: tst_acc
        graph_name = "Compare Accuracy"
        draw_model_graph(graph_name,
                         num_epoch,
                         tst_acc_dic,
                         graph_mode="acc",
                         save=args.s,
                         zoom_plot=args.z)
torch.save(state, filename) if is_best: shutil.copyfile(filename, 'model_best-{}.pth.tar'.format(arch)) model = Model().cuda() optimizer = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr, momentum, weight_decay) criterion = nn.CrossEntropyLoss().cuda() best_loss = 10 for epoch in range(epochs): # train for one epoch train(train_loader, model, criterion, optimizer, epoch) # evaluate on validation set loss = validate(val_loader, model, criterion) # remember best prec@1 and save checkpoint is_best = loss < best_loss best_loss = min(loss, best_loss) print(' * Best Loss: {}'.format(best_loss)) save_checkpoint( { 'epoch': epoch + 1, 'arch': arch, 'state_dict': model.state_dict(), 'best_loss': best_loss, 'optimizer': optimizer.state_dict(), }, is_best)
class Trainer(object):
    """Trains an AttentionalFactorizationMachine on CSV rating data.

    Owns the full loop: per-epoch training, validation, TensorBoard
    logging, best-model checkpointing, and early stopping.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Each run logs under <save_dir>/<MMDD_HHMMSS>.
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.log_path = os.path.join(config['train']['save_dir'], start_time)
        tb_path = os.path.join(self.log_path, 'logs')
        mkdir_p(tb_path)
        self.writer = WriterTensorboardX(tb_path)

        data_manager = CSVDataManager(config['data'])
        self.data_loader = data_manager.get_loader('train')
        self.valid_data_loader = data_manager.get_loader('val')

        self.model = AttentionalFactorizationMachine(data_manager.dims, config)
        self.model = self.model.to(self.device)

        # Only optimize parameters that require gradients.
        trainable_params = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        self.optimizer = SGD(trainable_params, **config['optimizer'])
        self.lr_scheduler = StepLR(self.optimizer, **config['lr_scheduler'])

        # Early-stopping state: best validation loss seen so far and the
        # number of consecutive epochs without improvement.
        self.best_val_loss = float('inf')
        self.satur_count = 0

    def _train_epoch(self, epoch):
        """Run one training epoch.

        Returns a dict with the mean train 'loss' for the epoch merged with
        the validation log ('val_loss') from :meth:`_valid_epoch`.
        """
        self.model.train()
        total_loss = 0
        self.writer.set_step(epoch)
        _trange = tqdm(self.data_loader, leave=True, desc='')
        for batch_idx, batch in enumerate(_trange):
            batch = [b.to(self.device) for b in batch]
            # Last tensor in the batch is the regression target; the rest
            # are the model inputs (data -> users, items, gens).
            data, target = batch[:-1], batch[-1]
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.mse_loss(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            if batch_idx % 10 == 0:
                _str = 'Train Epoch: {} Loss: {:.6f}'.format(
                    epoch, loss.item())
                _trange.set_description(_str)
        loss = total_loss / len(self.data_loader)
        self.writer.add_scalar('loss', loss)
        log = {'loss': loss}
        val_log = self._valid_epoch(epoch)
        log = {**log, **val_log}
        # Step the LR schedule once per epoch, after validation.
        self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """Evaluate MSE on the validation split; returns {'val_loss': ...}."""
        self.model.eval()
        total_val_loss = 0
        self.writer.set_step(epoch, 'valid')
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_data_loader):
                batch = [b.to(self.device) for b in batch]
                data, target = batch[:-1], batch[-1]
                output = self.model(data)
                loss = F.mse_loss(output, target)
                total_val_loss += loss.item()
        val_loss = total_val_loss / len(self.valid_data_loader)
        self.writer.add_scalar('loss', val_loss)
        return {'val_loss': val_loss}

    def train(self):
        """Main loop: train, log, checkpoint on improvement, early-stop."""
        print(self.model)
        for epoch in range(1, self.config['train']['epochs'] + 1):
            result = self._train_epoch(epoch)
            c_lr = self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar('lr', c_lr)
            log = pd.DataFrame([result]).T
            log.columns = ['']
            print(log)
            if self.best_val_loss > result['val_loss']:
                print('[IMPROVED]')
                # BUGFIX: best_val_loss was never updated (and satur_count
                # never reset), so every epoch compared against inf, printed
                # [IMPROVED], overwrote the checkpoint, and early stopping
                # could never fire.
                self.best_val_loss = result['val_loss']
                self.satur_count = 0
                chk_path = os.path.join(self.log_path, 'checkpoints')
                mkdir_p(chk_path)
                state = {
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }
                torch.save(state, os.path.join(chk_path, 'model_best.pth'))
                with open(os.path.join(chk_path, 'config.json'), 'w') as wj:
                    json.dump(self.config, wj)
            else:
                self.satur_count += 1
                if self.satur_count > self.config['train']['early_stop']:
                    break
class Trainer(object):
    """Base trainer for channel-detection models over an ISI channel.

    Supports plain supervised training, meta-training (MAML / FO-MAML over
    (support, query) word pairs), and self-supervised online adaptation at
    evaluation time. Subclasses are expected to provide the detector model
    (`initialize_detector` / `initialize_meta_detector`) and the loss
    (`calc_loss`); this base class owns configuration parsing, dataset and
    dataloader construction, the optimization setup, the training loops, and
    checkpoint save/load.
    """

    def __init__(self, config_path=None, **kwargs):
        # All hyperparameters start as None and are filled in either from
        # **kwargs or from the YAML config (kwargs take precedence, since
        # param_parser only sets attributes that are still None).
        # general
        self.run_name = None
        # code parameters
        self.use_ecc = None
        self.n_symbols = None
        # channel
        self.memory_length = None
        self.channel_type = None
        self.channel_coefficients = None
        self.noisy_est_var = None
        self.fading_in_channel = None
        self.fading_in_decoder = None
        self.fading_taps_type = None
        self.subframes_in_frame = None
        self.gamma = None
        # validation hyperparameters
        self.val_block_length = None
        self.val_frames = None
        self.val_SNR_start = None
        self.val_SNR_end = None
        self.val_SNR_step = None
        self.eval_mode = None
        # training hyperparameters
        self.train_block_length = None
        self.train_frames = None
        self.train_minibatch_num = None
        self.train_minibatch_size = None
        self.train_SNR_start = None
        self.train_SNR_end = None
        self.train_SNR_step = None
        self.lr = None  # learning rate
        self.loss_type = None
        self.optimizer_type = None
        # self-supervised online training
        self.self_supervised = None
        self.self_supervised_iterations = None
        self.ser_thresh = None
        self.meta_lr = None
        self.MAML = None
        self.online_meta = None
        self.weights_init = None
        self.window_size = None
        self.buffer_empty = None
        self.meta_train_iterations = None
        self.meta_j_num = None
        self.meta_subframes = None
        # seed
        self.noise_seed = None
        self.word_seed = None
        # weights dir
        self.weights_dir = None
        # if any kwargs are passed, initialize the dict with them
        self.initialize_by_kwargs(**kwargs)
        # initializes all none parameters above from config
        self.param_parser(config_path)
        # initializes word and noise generator from seed
        self.rand_gen = np.random.RandomState(self.noise_seed)
        self.word_rand_gen = np.random.RandomState(self.word_seed)
        # number of trellis states for a binary channel with this memory
        self.n_states = 2**self.memory_length
        # initialize matrices, datasets and detector
        self.initialize_dataloaders()
        self.initialize_detector()
        self.initialize_meta_detector()
        # calculate data subframes indices. We will calculate ser only over
        # these values. Index 0 of every frame (x % subframes_in_frame == 0)
        # is the pilot subframe and is excluded from SER computation.
        self.data_indices = torch.Tensor(
            list(
                filter(lambda x: x % self.subframes_in_frame != 0, [
                    i for i in range(self.val_frames * self.subframes_in_frame)
                ]))).long()

    def initialize_by_kwargs(self, **kwargs):
        """Set every keyword argument as an attribute on the trainer."""
        for k, v in kwargs.items():
            setattr(self, k, v)

    def param_parser(self, config_path: str):
        """
        Parse the config, load all attributes into the trainer
        :param config_path: path to config
        """
        if config_path is None:
            config_path = CONFIG_PATH

        with open(config_path) as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)

        # set attribute of Trainer with every config item
        # (only attributes still None — kwargs passed to __init__ win)
        for k, v in self.config.items():
            try:
                if getattr(self, k) is None:
                    setattr(self, k, v)
            except AttributeError:
                pass

        if self.weights_dir is None:
            self.weights_dir = os.path.join(WEIGHTS_DIR, self.run_name)
            if not os.path.exists(self.weights_dir) and len(self.weights_dir):
                os.makedirs(self.weights_dir)
                # save config in output dir
                copyfile(config_path,
                         os.path.join(self.weights_dir, "config.yaml"))

    def get_name(self):
        return self.__name__()

    def initialize_detector(self):
        """
        Every trainer must have some base detector model
        """
        # Subclasses override; base class has no concrete detector.
        self.detector = None
        pass

    def initialize_meta_detector(self):
        """
        Every trainer must have some base detector model
        """
        # Subclasses override; functional detector used in meta-updates.
        self.meta_detector = None
        pass

    def check_eval_mode(self):
        # Only two evaluation modes are supported.
        if self.eval_mode != 'aggregated' and self.eval_mode != 'by_word':
            raise ValueError("No such eval mode!!!")

    # calculate train loss
    def calc_loss(self, soft_estimation: torch.Tensor,
                  transmitted_words: torch.Tensor) -> torch.Tensor:
        """
        Every trainer must have some loss calculation
        """
        pass

    # setup the optimization algorithm
    def deep_learning_setup(self):
        """
        Sets up the optimizer and loss criterion
        """
        # Optimizer over trainable detector parameters only.
        if self.optimizer_type == 'Adam':
            self.optimizer = Adam(filter(lambda p: p.requires_grad,
                                         self.detector.parameters()),
                                  lr=self.lr)
        elif self.optimizer_type == 'RMSprop':
            self.optimizer = RMSprop(filter(lambda p: p.requires_grad,
                                            self.detector.parameters()),
                                     lr=self.lr)
        elif self.optimizer_type == 'SGD':
            self.optimizer = SGD(filter(lambda p: p.requires_grad,
                                        self.detector.parameters()),
                                 lr=self.lr)
        else:
            raise NotImplementedError("No such optimizer implemented!!!")
        if self.loss_type == 'BCE':
            self.criterion = BCELoss().to(device)
        elif self.loss_type == 'CrossEntropy':
            self.criterion = CrossEntropyLoss().to(device)
        elif self.loss_type == 'MSE':
            self.criterion = MSELoss().to(device)
        else:
            raise NotImplementedError("No such loss function implemented!!!")

    def initialize_dataloaders(self):
        """
        Sets up the data loader - a generator from which we draw batches, in
        iterations
        """
        self.snr_range = {
            'train': np.arange(self.train_SNR_start,
                               self.train_SNR_end + 1,
                               step=self.train_SNR_step),
            'val': np.arange(self.val_SNR_start,
                             self.val_SNR_end + 1,
                             step=self.val_SNR_step)
        }
        self.frames_per_phase = {
            'train': self.train_frames,
            'val': self.val_frames
        }
        self.block_lengths = {
            'train': self.train_block_length,
            'val': self.val_block_length
        }
        # NOTE(review): this rebinds self.channel_coefficients from the
        # scalar config value to a per-phase dict — training always uses
        # 'time_decay' coefficients while validation uses the configured one.
        self.channel_coefficients = {
            'train': 'time_decay',
            'val': self.channel_coefficients
        }
        # With ECC enabled, each block carries 8 parity bits per symbol.
        self.transmission_lengths = {
            'train':
            self.train_block_length if not self.use_ecc else
            self.train_block_length + 8 * self.n_symbols,
            'val':
            self.val_block_length if not self.use_ecc else
            self.val_block_length + 8 * self.n_symbols
        }
        self.channel_dataset = {
            phase: ChannelModelDataset(
                channel_type=self.channel_type,
                block_length=self.block_lengths[phase],
                transmission_length=self.transmission_lengths[phase],
                words=self.frames_per_phase[phase] * self.subframes_in_frame,
                memory_length=self.memory_length,
                channel_coefficients=self.channel_coefficients[phase],
                random=self.rand_gen,
                word_rand_gen=self.word_rand_gen,
                noisy_est_var=self.noisy_est_var,
                use_ecc=self.use_ecc,
                n_symbols=self.n_symbols,
                fading_taps_type=self.fading_taps_type,
                fading_in_channel=self.fading_in_channel,
                fading_in_decoder=self.fading_in_decoder,
                phase=phase)
            for phase in ['train', 'val']
        }
        self.dataloaders = {
            phase: torch.utils.data.DataLoader(self.channel_dataset[phase])
            for phase in ['train', 'val']
        }

    def online_training(self, tx: torch.Tensor, rx: torch.Tensor):
        # Hook for subclasses: one self-supervised adaptation step on a
        # single (tx, rx) word pair. Base class is a no-op.
        pass

    def single_eval_at_point(self, snr: float, gamma: float) -> float:
        """
        Evaluation at a single snr.
        :param snr: indice of snr in the snrs vector
        :return: ser for batch
        """
        # draw words of given gamma for all snrs
        transmitted_words, received_words = self.channel_dataset[
            'val'].__getitem__(snr_list=[snr], gamma=gamma)

        # decode and calculate accuracy
        detected_words = self.detector(received_words, 'val', snr, gamma)

        if self.use_ecc:
            decoded_words = [
                decode(detected_word, self.n_symbols)
                for detected_word in detected_words.cpu().numpy()
            ]
            detected_words = torch.Tensor(decoded_words).to(device)

        # SER only over the data (non-pilot) subframes
        ser, fer, err_indices = calculate_error_rates(
            detected_words[self.data_indices],
            transmitted_words[self.data_indices])

        return ser

    def gamma_eval(self, gamma: float) -> np.ndarray:
        """
        Evaluation at a single gamma value.
        :return: ser for batch.
        """
        ser_total = np.zeros(len(self.snr_range['val']))
        for snr_ind, snr in enumerate(self.snr_range['val']):
            # load the best weights trained for this (snr, gamma) point
            self.load_weights(snr, gamma)
            ser_total[snr_ind] = self.single_eval_at_point(snr, gamma)
        return ser_total

    def evaluate_at_point(self) -> np.ndarray:
        """
        Monte-Carlo simulation over validation SNRs range
        :return: ber, fer, iterations vectors
        """
        ser_total = np.zeros(len(self.snr_range['val']))
        with torch.no_grad():
            print(f'Starts evaluation at gamma {self.gamma}')
            start = time()
            ser_total += self.gamma_eval(self.gamma)
            print(f'Done. time: {time() - start}, ser: {ser_total}')
        return ser_total

    def eval_by_word(self, snr: float,
                     gamma: float) -> Union[float, np.ndarray]:
        """Word-by-word evaluation with optional online/meta adaptation.

        Iterates over validation words in order; pilot subframes refresh the
        training buffer with ground truth, data subframes are detected,
        decoded and scored. Returns the per-word SER vector.
        """
        if self.self_supervised:
            self.deep_learning_setup()
        total_ser = 0
        # draw words of given gamma for all snrs
        transmitted_words, received_words = self.channel_dataset[
            'val'].__getitem__(snr_list=[snr], gamma=gamma)
        ser_by_word = np.zeros(transmitted_words.shape[0])
        # saved detector is used to initialize the decoder in meta learning
        # loops
        self.saved_detector = copy.deepcopy(self.detector)
        # query for all detected words
        if self.buffer_empty:
            buffer_rx = torch.empty([0, received_words.shape[1]]).to(device)
            buffer_tx = torch.empty([0, received_words.shape[1]]).to(device)
            buffer_ser = torch.empty([0]).to(device)
        else:
            # draw words from different channels to pre-fill the buffer
            buffer_tx, buffer_rx = self.channel_dataset['train'].__getitem__(
                snr_list=[snr], gamma=gamma)
            buffer_ser = torch.zeros(buffer_rx.shape[0]).to(device)
            buffer_tx = torch.cat([
                torch.Tensor(
                    encode(transmitted_word.int().cpu().numpy(),
                           self.n_symbols).reshape(1, -1)).to(device)
                for transmitted_word in buffer_tx
            ],
                                  dim=0)
        # relative indices of the (support, query) windows inside the buffer
        support_idx = torch.arange(-self.window_size - 1, -1).long().to(device)
        query_idx = -1 * torch.ones(1).long().to(device)
        for count, (transmitted_word, received_word) in enumerate(
                zip(transmitted_words, received_words)):
            transmitted_word, received_word = transmitted_word.reshape(
                1, -1), received_word.reshape(1, -1)
            # detect
            detected_word = self.detector(received_word, 'val', snr, gamma,
                                          count)
            if count in self.data_indices:
                # decode
                decoded_word = [
                    decode(detected_word, self.n_symbols)
                    for detected_word in detected_word.cpu().numpy()
                ]
                decoded_word = torch.Tensor(decoded_word).to(device)
                # calculate accuracy
                ser, fer, err_indices = calculate_error_rates(
                    decoded_word, transmitted_word)
                # encode word again
                decoded_word_array = decoded_word.int().cpu().numpy()
                encoded_word = torch.Tensor(
                    encode(decoded_word_array,
                           self.n_symbols).reshape(1, -1)).to(device)
                errors_num = torch.sum(
                    torch.abs(encoded_word - detected_word)).item()
                print('*' * 20)
                print(f'current: {count, ser, errors_num}')
                total_ser += ser
                ser_by_word[count] = ser
            else:
                # pilot subframe: ground truth is known, SER treated as 0
                print('*' * 20)
                print(f'current: {count}, Pilot')
                # encode word again
                decoded_word_array = transmitted_word.int().cpu().numpy()
                encoded_word = torch.Tensor(
                    encode(decoded_word_array,
                           self.n_symbols).reshape(1, -1)).to(device)
                ser = 0
            # save the encoded word in the buffer (only confident words)
            if ser <= self.ser_thresh:
                buffer_rx = torch.cat([buffer_rx, received_word])
                buffer_tx = torch.cat([
                    buffer_tx,
                    detected_word.reshape(1, -1)
                    if ser > 0 else encoded_word.reshape(1, -1)
                ],
                                      dim=0)
                buffer_ser = torch.cat(
                    [buffer_ser,
                     torch.FloatTensor([ser]).to(device)])
                # keep the buffer at fixed length when it was pre-filled
                # NOTE(review): original indentation ambiguous in the
                # extracted text — assumed this trim belongs to the append
                # branch; confirm against the upstream repository.
                if not self.buffer_empty:
                    buffer_rx = buffer_rx[1:]
                    buffer_tx = buffer_tx[1:]
                    buffer_ser = buffer_ser[1:]
            # periodic meta-training over the accumulated buffer
            if self.online_meta and count % self.meta_subframes == 0 and count >= self.meta_subframes and \
                    buffer_rx.shape[0] > 2:  # self.subframes_in_frame
                print('meta-training')
                self.meta_weights_init()
                for i in range(self.meta_train_iterations):
                    # sample random query anchors from the buffer
                    j_hat_values = torch.unique(
                        torch.randint(low=0,
                                      high=buffer_rx.shape[0] - 2,
                                      size=[self.meta_j_num])).to(device)
                    for j_hat in j_hat_values:
                        cur_support_idx = j_hat + support_idx + 1
                        cur_query_idx = j_hat + query_idx + 1
                        self.meta_train_loop(buffer_rx, buffer_tx,
                                             cur_support_idx, cur_query_idx)
                # keep a copy of the meta-trained detector for the next init
                copy_model(source_model=self.detector,
                           dest_model=self.saved_detector)
            if self.self_supervised and ser <= self.ser_thresh:
                # use last word inserted in the buffer for training
                self.online_training(buffer_tx[-1].reshape(1, -1),
                                     buffer_rx[-1].reshape(1, -1))
            if (count + 1) % 10 == 0:
                print(
                    f'Self-supervised: {count + 1}/{transmitted_words.shape[0]}, SER {total_ser / (count + 1)}'
                )
        total_ser /= transmitted_words.shape[0]
        print(f'Final ser: {total_ser}')
        return ser_by_word

    def meta_weights_init(self):
        """Re-initialize detector weights before an online meta-training
        round, according to the configured strategy."""
        if self.weights_init == 'random':
            self.initialize_detector()
            self.deep_learning_setup()
        elif self.weights_init == 'last_frame':
            copy_model(source_model=self.saved_detector,
                       dest_model=self.detector)
        elif self.weights_init == 'meta_training':
            snr = self.snr_range['val'][0]
            self.load_weights(snr, self.gamma)
        else:
            raise ValueError('No such weights init!!!')

    def evaluate(self) -> np.ndarray:
        """
        Evaluation either happens in a point aggregation way, or in a
        word-by-word fashion
        """
        # eval with training
        self.check_eval_mode()
        if self.eval_mode == 'by_word':
            if not self.use_ecc:
                raise ValueError('Only supports ecc')
            snr = self.snr_range['val'][0]
            self.load_weights(snr, self.gamma)
            return self.eval_by_word(snr, self.gamma)
        else:
            return self.evaluate_at_point()

    def meta_train(self):
        """
        Main meta-training loop. Runs in minibatches, each minibatch is
        split to pairs of following words. The pairs are comprised of
        (support,query) words. Evaluates performance over validation SNRs.
        Saves weights every so and so iterations.
        """
        # initialize weights and loss
        for snr in self.snr_range['train']:
            print(f'SNR - {snr}, Gamma - {self.gamma}')

            # initialize weights and loss
            self.initialize_detector()
            self.deep_learning_setup()

            for minibatch in range(1, self.train_minibatch_num + 1):
                # draw words from different channels
                transmitted_words, received_words = self.channel_dataset[
                    'train'].__getitem__(snr_list=[snr], gamma=self.gamma)
                support_idx = torch.arange(-self.window_size - 1,
                                           -1).long().to(device)
                query_idx = -1 * torch.ones(1).long().to(device)
                # random query anchors within the drawn words
                j_hat_values = torch.unique(
                    torch.randint(low=self.window_size,
                                  high=transmitted_words.shape[0],
                                  size=[self.meta_j_num])).to(device)
                if self.use_ecc:
                    transmitted_words = torch.cat([
                        torch.Tensor(
                            encode(transmitted_word.int().cpu().numpy(),
                                   self.n_symbols).reshape(1, -1)).to(device)
                        for transmitted_word in transmitted_words
                    ],
                                                  dim=0)
                loss_query = 0
                for j_hat in j_hat_values:
                    cur_support_idx = j_hat + support_idx + 1
                    cur_query_idx = j_hat + query_idx + 1
                    loss_query += self.meta_train_loop(received_words,
                                                       transmitted_words,
                                                       cur_support_idx,
                                                       cur_query_idx)
                # evaluate performance
                ser = self.single_eval_at_point(snr, self.gamma)
                print(
                    f'Minibatch {minibatch}, ser - {ser}, loss - {loss_query}')
                # save best weights
                self.save_weights(float(loss_query), snr, self.gamma)

    def meta_train_loop(self, received_words: torch.Tensor,
                        transmitted_words: torch.Tensor,
                        support_idx: torch.Tensor, query_idx: torch.Tensor):
        """One MAML step: inner update on the support pair, outer
        (meta) gradient step from the query pair. Returns the query loss."""
        # divide the words to following pairs - (support,query)
        support_tx, support_rx = transmitted_words[
            support_idx], received_words[support_idx]
        query_tx, query_rx = transmitted_words[query_idx], received_words[
            query_idx]

        # local update (with support set)
        para_list_detector = list(
            map(lambda p: p[0], zip(self.detector.parameters())))
        soft_estimation_supp = self.meta_detector(support_rx, 'train',
                                                  para_list_detector)
        loss_supp = self.calc_loss(soft_estimation=soft_estimation_supp,
                                   transmitted_words=support_tx)

        # set create_graph to True for MAML, False for FO-MAML
        local_grad = torch.autograd.grad(loss_supp,
                                         para_list_detector,
                                         create_graph=self.MAML)
        updated_para_list_detector = list(
            map(lambda p: p[1] - self.meta_lr * p[0],
                zip(local_grad, para_list_detector)))

        # meta-update (with query set) should be same channel with support
        # set
        soft_estimation_query = self.meta_detector(
            query_rx, 'train', updated_para_list_detector)
        loss_query = self.calc_loss(soft_estimation=soft_estimation_query,
                                    transmitted_words=query_tx)
        meta_grad = torch.autograd.grad(loss_query,
                                        para_list_detector,
                                        create_graph=False)

        # manually write the meta-gradient into the original parameters
        ind_param = 0
        for param in self.detector.parameters():
            param.grad = None  # zero_grad
            param.grad = meta_grad[ind_param]
            ind_param += 1
        self.optimizer.step()
        return loss_query

    def train(self):
        """
        Main training loop. Runs in minibatches.
        Evaluates performance over validation SNRs.
        Saves weights every so and so iterations.
        """
        # batches loop
        for snr in self.snr_range['train']:
            print(f'SNR - {snr}, Gamma - {self.gamma}')

            # initialize weights and loss
            self.initialize_detector()
            self.deep_learning_setup()
            best_ser = math.inf

            for minibatch in range(1, self.train_minibatch_num + 1):
                # draw words
                transmitted_words, received_words = self.channel_dataset[
                    'train'].__getitem__(snr_list=[snr], gamma=self.gamma)
                # run training loops
                current_loss = 0
                for i in range(self.train_frames * self.subframes_in_frame):
                    # pass through detector
                    soft_estimation = self.detector(
                        received_words[i].reshape(1, -1), 'train')
                    current_loss += self.run_train_loop(
                        soft_estimation, transmitted_words[i].reshape(1, -1))
                # evaluate performance
                ser = self.single_eval_at_point(snr, self.gamma)
                print(
                    f'Minibatch {minibatch}, ser - {ser}, loss {current_loss}')
                # save best weights
                if ser < best_ser:
                    self.save_weights(current_loss, snr, self.gamma)
                    best_ser = ser
            print(f'best ser - {best_ser}')
            print('*' * 50)

    def run_train_loop(self, soft_estimation: torch.Tensor,
                       transmitted_words: torch.Tensor):
        """One supervised step; returns the scalar loss (np.nan on NaN)."""
        # calculate loss
        loss = self.calc_loss(soft_estimation=soft_estimation,
                              transmitted_words=transmitted_words)
        # if loss is Nan inform the user
        if torch.sum(torch.isnan(loss)):
            print('Nan value')
            return np.nan
        current_loss = loss.item()
        # back propagation (grads cleared manually instead of zero_grad())
        for param in self.detector.parameters():
            param.grad = None
        loss.backward()
        self.optimizer.step()
        return current_loss

    def save_weights(self, current_loss: float, snr: float, gamma: float):
        # Checkpoint detector + optimizer state per (snr, gamma) point.
        torch.save(
            {
                'model_state_dict': self.detector.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': current_loss
            }, os.path.join(self.weights_dir, f'snr_{snr}_gamma_{gamma}.pt'))

    def load_weights(self, snr: float, gamma: float):
        """
        Loads detector's weights defined by the [snr,gamma] from checkpoint,
        if exists
        """
        # NOTE(review): os.path.join(...) always returns a non-empty string,
        # so this condition is always True and the else branch below is dead
        # code — probably os.path.isfile(...) was intended; confirm upstream
        # before changing, since the isfile check inside already handles the
        # missing-file case by training first.
        if os.path.join(self.weights_dir, f'snr_{snr}_gamma_{gamma}.pt'):
            print(f'loading model from snr {snr} and gamma {gamma}')
            weights_path = os.path.join(self.weights_dir,
                                        f'snr_{snr}_gamma_{gamma}.pt')
            if not os.path.isfile(weights_path):
                # if weights do not exist, train on the synthetic channel.
                # Then validate on the test channel.
                self.fading_taps_type = 1
                os.makedirs(self.weights_dir, exist_ok=True)
                self.train()
                self.fading_taps_type = 2
            checkpoint = torch.load(weights_path)
            try:
                self.detector.load_state_dict(checkpoint['model_state_dict'])
            except Exception:
                raise ValueError("Wrong run directory!!!")
        else:
            print(
                f'No checkpoint for snr {snr} and gamma {gamma} in run "{self.run_name}", starting from scratch'
            )

    def select_batch(
        self, gt_examples: torch.LongTensor, soft_estimation: torch.Tensor
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Select a batch from the input and gt labels
        :param gt_examples: training labels
        :param soft_estimation: the soft approximation, distribution over
        states (per word)
        :return: selected batch from the entire "epoch", contains both
        labels and the NN soft approximation
        """
        rand_ind = torch.multinomial(
            torch.arange(gt_examples.shape[0]).float(),
            self.train_minibatch_size).long().to(device)
        return gt_examples[rand_ind], soft_estimation[rand_ind]
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          rank=0,
          world=1,
          no_apex=False,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None,
          rotate_augment=False,
          augment_brightness=0.0,
          augment_contrast=0.0,
          augment_hue=0.0,
          augment_saturation=0.0,
          regularization_l2=0.0001,
          rotated_bbox=False,
          absolute_angle=False):
    """Train the model on the given dataset.

    Iteration-based (not epoch-based) loop: trains until `iterations`,
    checkpointing via `state` (which also carries resume information:
    'iteration', 'optimizer', 'scheduler'), logging to TensorBoard when
    `logdir` is given, and running validation every `val_iterations`.
    Mixed precision uses NVIDIA apex AMP by default, or native
    torch.cuda.amp (GradScaler/autocast) when `no_apex` is True.
    Multi-GPU training wraps the model in (A)DDP when `world > 1`;
    only `rank` 0 prints, logs, and saves.
    """
    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        # channels_last improves conv throughput on tensor-core GPUs
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=regularization_l2,
                    momentum=0.9)

    is_master = rank == 0
    if not no_apex:
        # NOTE(review): the static scale is the string "128.0" (apex parses
        # string loss scales); confirm against the apex.amp API if changed.
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2' if mixed_precision else 'O0',
            keep_batchnorm_fp32=True,
            loss_scale=loss_scale,
            verbosity=is_master)

    if world > 1:
        # apex DDP (ADDP) when using apex; torch DDP otherwise
        model = DDP(model, device_ids=[rank]) if no_apex else ADDP(model)
    model.train()

    # resume optimizer state if present in the checkpoint
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # linear warmup from 0.1x to 1.0x, then step decay at milestones
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    if rotated_bbox:
        if use_dali:
            raise NotImplementedError(
                "This repo does not currently support DALI for rotated bbox detections."
            )
        data_iterator = RotatedDataIterator(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation,
            absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation)
    if verbose:
        print(data_iterator)

    if verbose:
        print(' device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'GPU' if world == 1 else 'GPUs'))
        print(' batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer (master rank only)
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    # GradScaler is only exercised on the native-AMP (no_apex) path
    scaler = GradScaler()
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # Forward pass
            profiler.start('fw')
            optimizer.zero_grad()
            if not no_apex:
                cls_loss, box_loss = model([
                    data.contiguous(memory_format=torch.channels_last),
                    target
                ])
            else:
                with autocast():
                    cls_loss, box_loss = model([
                        data.contiguous(memory_format=torch.channels_last),
                        target
                    ])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if not no_apex:
                with amp.scale_loss(cls_loss + box_loss,
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            # per-iteration LR schedule (milestones are iteration counts)
            scheduler.step()

            # Reduce all losses across ranks for logging
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')
            iteration += 1

            profiler.bump('train')
            # log + checkpoint roughly once a minute (or at the very end)
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(
                        profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights (resume info included in state)
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                # don't let Ctrl-C corrupt the checkpoint mid-write
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # periodic validation (and at the final iteration)
            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                stats = infer(model,
                              val_path,
                              None,
                              resize,
                              max_size,
                              batch_size,
                              annotations=val_annotations,
                              mixed_precision=mixed_precision,
                              is_master=is_master,
                              world=world,
                              use_dali=use_dali,
                              no_apex=no_apex,
                              is_validation=True,
                              verbose=False,
                              rotated_bbox=rotated_bbox)
                model.train()
                # NOTE(review): the '[email protected]' scalar tags below look like an
                # extraction artifact of 'mAP@0.50'/'mAP@0.75' — kept
                # byte-identical here; confirm against the upstream repo.
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar('Validation_Precision/mAP', stats[0],
                                      iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[1], iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[2], iteration)
                    writer.add_scalar('Validation_Precision/mAP (small)',
                                      stats[3], iteration)
                    writer.add_scalar('Validation_Precision/mAP (medium)',
                                      stats[4], iteration)
                    writer.add_scalar('Validation_Precision/mAP (large)',
                                      stats[5], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 1 Dets)',
                                      stats[6], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 10 Dets)',
                                      stats[7], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 100 Dets)',
                                      stats[8], iteration)
                    writer.add_scalar('Validation_Recall/mAR (small)',
                                      stats[9], iteration)
                    writer.add_scalar('Validation_Recall/mAR (medium)',
                                      stats[10], iteration)
                    writer.add_scalar('Validation_Recall/mAR (large)',
                                      stats[11], iteration)

            if (iteration == iterations
                    and not rotated_bbox) or (iteration > iterations
                                              and rotated_bbox):
                break

    if is_master and logdir is not None:
        writer.close()
def main(args: argparse.Namespace):
    """Entry point for adversarial domain-adaptive semantic segmentation.

    Builds source/target dataloaders, a segmentation model plus a domain
    discriminator, optimizers and poly-decay LR schedulers, then either runs a
    single test pass (``args.phase == 'test'``) or trains for
    ``args.epochs`` epochs, checkpointing every epoch and keeping a copy of
    the best-mIoU checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding forces deterministic cuDNN, which can slow training.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: source domain gets color jitter; target domain uses a
    # fixed 2:1 crop ratio; validation resizes to the test resolution.
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset, batch_size=1,
                                   shuffle=False, pin_memory=True)

    # Infinite iterators so train() can draw batches without epoch bookkeeping.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler (poly decay over epochs * iters_per_epoch)
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(), lr=args.lr_d,
                       betas=(0.9, 0.99))
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this lambda,
    # and the lambda itself multiplies by args.lr again — effective lr is
    # args.lr^2 * poly unless get_parameters() normalizes base lrs to 1; verify.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint (restores all four optimizer/scheduler
    # states plus the next epoch index)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion); interp_* upsample logits to label size.
    # args.*_size are (W, H), nn.Upsample takes (H, W), hence the [::-1].
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the input image, ground-truth label and prediction as PNGs.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    # test-only mode: one validation pass with visualization, then exit
    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes (only the classes the
        # benchmark evaluates, not all trained classes)
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint; 'best' is a copy of the
        # highest-mIoU epoch checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
def main_worker(gpu, ngpus_per_node, args):
    """Per-GPU worker: fine-tune a pretrained net converted to a Bayesian net.

    Converts a pretrained architecture with ``to_bayesian``, splits parameters
    into mean ('mu') and variance ('psi') groups with separate optimizers,
    optionally resumes from a checkpoint, then trains for ``args.epochs``
    epochs and runs a final evaluation.  Supports single-GPU and
    DistributedDataParallel execution.
    """
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    # log is a (file, gpu) pair; print_log presumably writes only on one rank —
    # TODO confirm against print_log's implementation.
    log = open(
        os.path.join(
            args.save_path, 'log_seed{}{}.txt'.format(
                args.manualSeed, '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

    # Build pretrained net, drop dropout, convert to Bayesian form and
    # unfreeze all parameters.
    net = models.__dict__[args.arch](pretrained=True)
    disable_dropout(net)
    net = to_bayesian(net, args.psi_init_range)
    net.apply(unfreeze)

    print_log("Python  version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log(
        "Number of parameters: {}".format(
            sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * gpus-per-node + local gpu index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url + ":" +
                                args.dist_port,
                                world_size=args.world_size,
                                rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        # Per-process batch size so the global batch stays constant.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    # Split parameters: 'psi' (variance) params get PsiSGD, the rest plain SGD.
    mus, psis = [], []
    for name, param in net.named_parameters():
        if 'psi' in name:
            psis.append(param)
        else:
            mus.append(param)
    mu_optimizer = SGD(mus,
                       args.learning_rate,
                       args.momentum,
                       weight_decay=args.decay,
                       nesterov=(args.momentum > 0.0))
    psi_optimizer = PsiSGD(psis,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            # Non-distributed nets need the 'module.' DDP prefix stripped.
            net.load_state_dict(
                checkpoint['state_dict'] if args.distributed else {
                    k.replace('module.', ''): v
                    for k, v in checkpoint['state_dict'].items()
                })
            mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
            psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> do not use any checkpoint for the model", log)

    cudnn.benchmark = True

    train_loader, ood_train_loader, test_loader, adv_loader, \
        fake_loader, adv_loader2 = load_dataset_ft(args)
    # PsiSGD scales its update by the dataset size.
    psi_optimizer.num_data = len(train_loader.dataset)

    if args.evaluate:
        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle the distributed samplers each epoch.
            train_loader.sampler.set_epoch(epoch)
            ood_train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                               epoch, args)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
            time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, ood_train_loader, net,
                                     criterion, mu_optimizer, psi_optimizer,
                                     epoch, args, log)
        # NOTE(review): validation is hard-coded to 0, so is_best is never True
        # and best_acc never changes during training — confirm this is intended.
        val_acc, val_los = 0, 0
        recorder.update(epoch, train_los, train_acc, val_acc, val_los)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # Only rank/gpu 0 writes checkpoints.
        if args.gpu == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'mu_optimizer': mu_optimizer.state_dict(),
                    'psi_optimizer': psi_optimizer.state_dict(),
                }, False, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, 20, 100)

    log[0].close()
triplet_loss = TripletLoss(margin=margin).forward( anchor=anc_hard_embedding, positive=pos_hard_embedding, negative=neg_hard_embedding).cuda() triplet_loss_sum += triplet_loss.item() num_valid_training_triplets += len(anc_hard_embedding) optimizer_model.zero_grad() triplet_loss.backward() optimizer_model.step() avg_triplet_loss = 0 if ( num_valid_training_triplets == 0) else triplet_loss_sum / num_valid_training_triplets print( 'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}' .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets)) torch.save( { 'epoch': epoch, 'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer_model.state_dict(), 'avg_triplet_loss': avg_triplet_loss, 'valid_training_triplets': num_valid_training_triplets }, './train_checkpoints/' + 'checkpoint_' + str(total_triplets) + '_' + str(epoch) + '_' + str(num_valid_training_triplets) + '.tar')
def main(input_len, epochs_num, hidden_size, batch_size, output_size, lr):
    """Train an LSTM-style Predictor on time-series data and evaluate it.

    Builds train/test tensors via ``mkDataSet``, trains with BCE loss for
    ``epochs_num`` epochs (logging accuracy as |output - label| <= 0.5),
    then runs a final evaluation on ``mkTestModelset`` data and saves the
    weights to 'weight.pth'.
    """
    start = datetime.datetime(1999, 1, 8)
    end = datetime.datetime(2016, 12, 31)
    #test_start = datetime.datetime(2015, 1, 8)
    #test_end = datetime.datetime(2016, 12, 31)
    training_size = 0
    test_size = 0
    train_x, train_t, test_x, test_t = mkDataSet(start, end, input_len)
    # 6 = input feature dimension — presumably fixed by mkDataSet; confirm.
    model = Predictor(6, hidden_size, output_size)
    test_x = torch.Tensor(test_x)
    test_t = torch.Tensor(test_t)
    train_x = torch.Tensor(train_x)
    train_t = torch.Tensor(train_t)
    #print(test_x.size())
    #print(test_t.size())
    #print(test_x)
    #print(test_t)
    #exit
    #test_x = torch.Tensor(test_x)
    #test_t = torch.Tensor(test_t)
    dataset = TensorDataset(train_x, train_t)
    loader_train = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    dataset = TensorDataset(test_x, test_t)
    loader_test = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    #dataset_loader = torch.utils.data.DataLoader(dataset,batch_size=4, shuffle=True,num_workers=2)
    #torch.backends.cudnn.benchmark=True
    optimizer = SGD(model.parameters(), lr)
    # NOTE(review): size_average is deprecated — reduction='sum' is the modern
    # equivalent; confirm the installed torch version still accepts it.
    criterion = torch.nn.BCELoss(size_average=False)
    # NOTE(review): scheduler.step() below is commented out, so this schedule
    # never takes effect.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[epochs_num * 0.3, epochs_num * 0.7],
        gamma=0.1,
        last_epoch=-1)
    loss_record = []
    count = 0
    for epoch in range(epochs_num):
        # training
        running_loss = 0.0
        training_accuracy = 0.0
        training_num = 0
        #scheduler.step()
        model.train()
        for i, data in enumerate(loader_train, 0):
            # split batch into input data and labels
            inputs, labels = data
            # reset the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            labels = labels.float()
            # cross-entropy (BCE) against the label data is computed here
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            #running_loss += loss.data[0]
            running_loss += loss.data.item() * 100
            # "correct" means the output is within 0.5 of the label
            training_accuracy += np.sum(
                np.abs((outputs.data - labels.data).numpy()) <= 0.5)
            # always-true comparison — just counts the number of elements
            training_num += np.sum(
                np.abs((outputs.data - labels.data).numpy()) != 10000)
        #test
        test_accuracy = 0.0
        test_num = 0
        model.eval()
        for i, data in enumerate(loader_test, 0):
            inputs, labels = data
            outputs = model(inputs)
            labels = labels.float()
            #print("#######################")
            #print(outputs)
            #print(labels)
            #print(output.t_(),label.t_())
            #print(np.abs((output.data - label.data).numpy()))
            test_accuracy += np.sum(
                np.abs((outputs.data - labels.data).numpy()) <= 0.5)
            test_num += np.sum(
                np.abs((outputs.data - labels.data).numpy()) != 100000)
        training_accuracy /= training_num
        test_accuracy /= test_num
        # (epoch + 1) % 1 is always 0, so this logs/saves every epoch
        if ((epoch + 1) % 1 == 0):
            print(
                '%d loss: %.3f, training_accuracy: %.5f, test_accuracy: %.5f' %
                (epoch + 1, running_loss, training_accuracy, test_accuracy))
            #print(output)
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, 'jikkenpath')
            #loss_record.append(running_loss)
    else:
        # for/else: runs once after the epoch loop completes (no break above)
        print(training_num)
        print(test_num)
    #print(loss_record)
    if (1):
        # final evaluation on a separate test set, batch size fixed at 5
        test_x, test_t = mkTestModelset(input_len)
        test_x = torch.Tensor(test_x)
        test_t = torch.Tensor(test_t)
        dataset = TensorDataset(test_x, test_t)
        loader_test = DataLoader(dataset, batch_size=5, shuffle=False)
        test_accuracy = 0.0
        test_num = 0
        av_testac = 0
        model.eval()
        for i, data in enumerate(loader_test, 0):
            inputs, labels = data
            outputs = model(inputs)
            labels = labels.float()
            # per-batch accuracy (counters reset every iteration)
            test_accuracy = 0.0
            test_num = 0
            test_accuracy += np.sum(
                np.abs((outputs.data - labels.data).numpy()) <= 0.5)
            test_num += np.sum(
                np.abs((outputs.data - labels.data).numpy()) != 100000)
            test_accuracy /= test_num
            av_testac += test_accuracy
            print(i, test_accuracy)
        else:
            # for/else: report the average accuracy over all batches
            print(av_testac / (i + 1))
        torch.save(model.state_dict(), 'weight.pth')
def train(train_dir, model_dir, config_path, checkpoint_path, n_steps,
          save_every, test_every, decay_every, n_speakers, n_utterances,
          seg_len):
    """Train a d-vector network with the GE2E loss.

    Draws batches of ``n_speakers`` x ``n_utterances`` segments, optimizes the
    embedding network and the GE2E criterion jointly with SGD + StepLR decay,
    logs loss/gradient norm to TensorBoard, and periodically validates and
    checkpoints.
    """
    # ------------------------------------------------------------------ data
    dataset = SEDataset(train_dir, n_utterances, seg_len)
    n_valid = 2 * n_speakers
    train_set, valid_set = random_split(
        dataset, [len(dataset) - n_valid, n_valid])
    loader_kwargs = dict(batch_size=n_speakers, shuffle=True, num_workers=4,
                         collate_fn=pad_batch, drop_last=True)
    train_loader = DataLoader(train_set, **loader_kwargs)
    valid_loader = DataLoader(valid_set, **loader_kwargs)
    batch_iter = iter(train_loader)

    assert len(train_set) >= n_speakers
    assert len(valid_set) >= n_speakers
    print(f"Training starts with {len(train_set)} speakers. "
          f"(and {len(valid_set)} speakers for validation)")

    # ---------------------------------------- model, loss and optimization
    net = DVector().load_config_file(config_path)
    ge2e = GE2ELoss()
    trainable = list(net.parameters()) + list(ge2e.parameters())
    optimizer = SGD(trainable, lr=0.01)
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    # Restore full training state when resuming from a checkpoint.
    global_step = 0
    if checkpoint_path is not None:
        ckpt = torch.load(checkpoint_path)
        global_step = ckpt["total_steps"]
        net.load_state_dict(ckpt["state_dict"])
        ge2e.load_state_dict(ckpt["criterion"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    ge2e = ge2e.to(device)
    writer = SummaryWriter(model_dir)
    progress = tqdm.trange(n_steps)

    # --------------------------------------------------------------- loop
    for step in progress:
        global_step += 1

        # Pull the next batch, restarting the loader when it is exhausted.
        try:
            batch = next(batch_iter)
        except StopIteration:
            batch_iter = iter(train_loader)
            batch = next(batch_iter)

        embd = net(batch.to(device)).view(n_speakers, n_utterances, -1)
        loss = ge2e(embd)

        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(trainable, max_norm=3)
        # Damp gradients of the embedding layer and the GE2E scale/bias,
        # as prescribed by the GE2E paper's training recipe.
        net.embedding.weight.grad.data *= 0.5
        ge2e.w.grad.data *= 0.01
        ge2e.b.grad.data *= 0.01
        optimizer.step()
        scheduler.step()

        progress.set_description(f"global = {global_step}, loss = {loss:.4f}")
        writer.add_scalar("Training loss", loss, global_step)
        writer.add_scalar("Gradient norm", grad_norm, global_step)

        # Periodic validation: one batch from the held-out speakers.
        if (step + 1) % test_every == 0:
            batch = next(iter(valid_loader))
            embd = net(batch.to(device)).view(n_speakers, n_utterances, -1)
            loss = ge2e(embd)
            writer.add_scalar("validation loss", loss, global_step)

        # Periodic checkpoint of the full training state.
        if (step + 1) % save_every == 0:
            ckpt_path = os.path.join(model_dir, f"ckpt-{global_step}.tar")
            ckpt_dict = {
                "total_steps": global_step,
                "state_dict": net.state_dict(),
                "criterion": ge2e.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(ckpt_dict, ckpt_path)

    print("Training completed.")
class LightHeadRCNN_Learner(Module):
    """Training/inference wrapper for a Light-Head R-CNN detector.

    Bundles the ResNet-101 extractor, RPN and light-head detection head with
    the COCO datasets, optimizer, LR warmup/schedule, TensorBoard logging and
    checkpointing used by :meth:`fit`.  All configuration comes from the
    project-local ``Config`` object.
    """

    def __init__(self, training=True):
        super(LightHeadRCNN_Learner, self).__init__()
        self.conf = Config()
        self.class_2_color = get_class_colors(self.conf)
        self.extractor = ResNet101Extractor(
            self.conf.pretrained_model_path).to(self.conf.device)
        self.rpn = RegionProposalNetwork().to(self.conf.device)
        # self.head = LightHeadRCNNResNet101_Head(self.conf.class_num + 1, self.conf.roi_size).to(self.conf.device)
        # NOTE(review): the trailing comma makes loc_normalize_mean a 1-tuple
        # containing the 4-tuple, i.e. ((0.,0.,0.,0.),).  Downstream use via
        # torch.tensor broadcasts a (1,4) shape, so it may still work — verify.
        self.loc_normalize_mean=(0., 0., 0., 0.),
        self.loc_normalize_std=(0.1, 0.1, 0.2, 0.2)
        self.head = LightHeadRCNNResNet101_Head(
            self.conf.class_num + 1,
            self.conf.roi_size,
            roi_align = self.conf.use_roi_align).to(self.conf.device)
        # NOTE(review): duplicate of the class_2_color assignment above.
        self.class_2_color = get_class_colors(self.conf)
        # Lightweight result containers for train and eval forward passes.
        self.detections = namedtuple('detections',
                                     ['roi_cls_locs', 'roi_scores', 'rois'])
        if training:
            self.train_dataset = coco_dataset(self.conf, mode = 'train')
            self.train_length = len(self.train_dataset)
            self.val_dataset = coco_dataset(self.conf, mode = 'val')
            self.val_length = len(self.val_dataset)
            self.anchor_target_creator = AnchorTargetCreator()
            self.proposal_target_creator = ProposalTargetCreator(
                loc_normalize_mean = self.loc_normalize_mean,
                loc_normalize_std = self.loc_normalize_std)
            self.step = 0
            # The first 8 head parameters get a 3x learning rate.
            self.optimizer = SGD([
                {'params' : get_trainables(self.extractor.parameters())},
                {'params' : self.rpn.parameters()},
                {'params' : [*self.head.parameters()][:8],
                 'lr' : self.conf.lr*3},
                {'params' : [*self.head.parameters()][8:]},
            ], lr = self.conf.lr,
               momentum=self.conf.momentum,
               weight_decay=self.conf.weight_decay)
            # Remember per-group base lrs for warmup/schedule scaling.
            self.base_lrs = [
                params['lr'] for params in self.optimizer.param_groups
            ]
            self.warm_up_duration = 5000
            self.warm_up_rate = 1 / 5
            self.train_outputs = namedtuple(
                'train_outputs',
                ['loss_total', 'rpn_loc_loss', 'rpn_cls_loss',
                 'ohem_roi_loc_loss', 'ohem_roi_cls_loss',
                 'total_roi_loc_loss', 'total_roi_cls_loss'])
            self.writer = SummaryWriter(self.conf.log_path)
            # Periodic-action intervals, derived from dataset length.
            self.board_loss_every = self.train_length // self.conf.board_loss_interval
            self.evaluate_every = self.train_length // self.conf.eval_interval
            self.eva_on_coco_every = self.train_length // self.conf.eval_coco_interval
            self.board_pred_image_every = self.train_length // self.conf.board_pred_image_interval
            self.save_every = self.train_length // self.conf.save_interval
            # only for debugging
            # self.board_loss_every = 5
            # self.evaluate_every = 6
            # self.eva_on_coco_every = 7
            # self.board_pred_image_every = 8
            # self.save_every = 10

    def set_training(self):
        # Train mode, but keep the extractor's BatchNorm layers frozen in eval.
        self.train()
        self.extractor.set_bn_eval()

    def lr_warmup(self):
        """Linearly ramp all group lrs from warm_up_rate*base to base."""
        assert self.step <= self.warm_up_duration, 'stop warm up after {} steps'.format(self.warm_up_duration)
        rate = self.warm_up_rate + (1 - self.warm_up_rate) * self.step / self.warm_up_duration
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate

    def lr_schedule(self, epoch):
        """Step decay: x1 before epoch 13, x0.1 until 16, x0.01 afterwards."""
        if epoch < 13:
            return
        elif epoch < 16:
            rate = 0.1
        else:
            rate = 0.01
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate
        print(self.optimizer)

    def forward(self, img_tensor, scale, bboxes=None, labels=None, force_eval=False):
        """Run the detector.

        In training (or ``force_eval``) mode returns a ``train_outputs`` tuple
        of losses; otherwise returns raw ``detections`` for post-processing.
        """
        img_tensor = img_tensor.to(self.conf.device)
        img_size = (img_tensor.shape[2], img_tensor.shape[3]) # H,W
        rpn_feature, roi_feature = self.extractor(img_tensor)
        rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(rpn_feature, img_size, scale)
        if self.training or force_eval:
            gt_rpn_loc, gt_rpn_labels = self.anchor_target_creator(bboxes, anchor, img_size)
            gt_rpn_labels = torch.tensor(gt_rpn_labels, dtype=torch.long).to(self.conf.device)
            if len(bboxes) == 0:
                # No ground truth: only the RPN classification loss is defined.
                rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)
                return self.train_outputs(rpn_cls_loss, 0, 0, 0, 0, 0, 0)
            sample_roi, gt_roi_locs, gt_roi_labels = self.proposal_target_creator(rois, bboxes, labels)
            roi_cls_locs, roi_scores = self.head(roi_feature, sample_roi)
            # roi_cls_locs, roi_scores, pool, h, rois = self.head(roi_feature, sample_roi)
            gt_rpn_loc = torch.tensor(gt_rpn_loc, dtype=torch.float).to(self.conf.device)
            gt_roi_locs = torch.tensor(gt_roi_locs, dtype=torch.float).to(self.conf.device)
            gt_roi_labels = torch.tensor(gt_roi_labels, dtype=torch.long).to(self.conf.device)
            rpn_loc_loss = fast_rcnn_loc_loss(rpn_locs[0], gt_rpn_loc, gt_rpn_labels, sigma=self.conf.rpn_sigma)
            rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)
            # OHEM keeps only the hardest n_ohem_sample RoIs in the head loss.
            ohem_roi_loc_loss, \
            ohem_roi_cls_loss, \
            total_roi_loc_loss, \
            total_roi_cls_loss = OHEM_loss(roi_cls_locs, roi_scores,
                                           gt_roi_locs, gt_roi_labels,
                                           self.conf.n_ohem_sample,
                                           self.conf.roi_sigma)
            loss_total = rpn_loc_loss + rpn_cls_loss + ohem_roi_loc_loss + ohem_roi_cls_loss
            # if loss_total.item() > 1000.:
            #     print('ohem_roi_loc_loss : {}, ohem_roi_cls_loss : {}'.format(ohem_roi_loc_loss, ohem_roi_cls_loss))
            #     torch.save(pool, 'pool_debug.pth')
            #     torch.save(h, 'h_debug.pth')
            #     np.save('rois_debug', rois)
            #     torch.save(roi_cls_locs, 'roi_cls_locs_debug.pth')
            #     torch.save(roi_scores, 'roi_scores_debug.pth')
            #     torch.save(gt_roi_locs, 'gt_roi_locs_debug.pth')
            #     torch.save(gt_roi_labels, 'gt_roi_labels_debug.pth')
            #     pdb.set_trace()
            return self.train_outputs(loss_total,
                                      rpn_loc_loss.item(),
                                      rpn_cls_loss.item(),
                                      ohem_roi_loc_loss.item(),
                                      ohem_roi_cls_loss.item(),
                                      total_roi_loc_loss,
                                      total_roi_cls_loss)
        else:
            roi_cls_locs, roi_scores = self.head(roi_feature, rois)
            return self.detections(roi_cls_locs, roi_scores, rois)

    def eval_predict(self, img, preset = 'evaluate', use_softnms = False):
        """Predict on one CHW uint8-style array; returns COCO-eval-shaped lists."""
        if type(img) == list:
            img = img[0]
        img = Image.fromarray(img.transpose(1,2,0).astype('uint8'))
        bboxes, labels, scores = self.predict_on_img(img, preset, use_softnms, original_size = True)
        bboxes = y1x1y2x2_2_x1y1x2y2(bboxes)
        return [bboxes], [labels], [scores]

    def predict_on_img(self, img, preset = 'evaluate', use_softnms=False, return_img = False, with_scores = False, original_size = False):
        '''
        inputs :
            imgs : PIL Image
        return :
            PIL Image (if return_img) or bboxes_group and labels_group
        '''
        self.eval()
        self.use_preset(preset)
        with torch.no_grad():
            orig_size = img.size # W,H
            img = np.asarray(img).transpose(2,0,1)
            img, scale = prepare_img(self.conf, img, -1)
            img = torch.tensor(img).unsqueeze(0)
            img_size = (img.shape[2], img.shape[3]) # H,W
            detections = self.forward(img, scale)
            n_sample = len(detections.roi_cls_locs)
            n_class = self.conf.class_num + 1
            # Undo the target normalization applied at training time.
            roi_cls_locs = detections.roi_cls_locs.reshape((n_sample, -1, 4)).reshape([-1,4])
            roi_cls_locs = roi_cls_locs * torch.tensor(self.loc_normalize_std, device=self.conf.device) + torch.tensor(self.loc_normalize_mean, device=self.conf.device)
            # Repeat each RoI once per class so loc2bbox decodes per-class boxes.
            rois = torch.tensor(detections.rois.repeat(n_class,0), dtype=torch.float).to(self.conf.device)
            raw_cls_bboxes = loc2bbox(rois, roi_cls_locs)
            # Clip decoded boxes to the (resized) image bounds, in place.
            torch.clamp(raw_cls_bboxes[:,0::2], 0, img_size[1], out = raw_cls_bboxes[:,0::2] )
            torch.clamp(raw_cls_bboxes[:,1::2], 0, img_size[0], out = raw_cls_bboxes[:,1::2] )
            raw_cls_bboxes = raw_cls_bboxes.reshape([n_sample, n_class, 4])
            raw_prob = F.softmax(detections.roi_scores, dim=1)
            bboxes, labels, scores = self._suppress(raw_cls_bboxes, raw_prob, use_softnms)
            if len(bboxes) == len(labels) == len(scores) == 0:
                if not return_img:
                    return [], [], []
                else:
                    return to_img(self.conf, img[0])
            # Keep at most max_n_predict detections, highest scores first.
            _, indices = scores.sort(descending=True)
            bboxes = bboxes[indices]
            labels = labels[indices]
            scores = scores[indices]
            if len(bboxes) > self.max_n_predict:
                bboxes = bboxes[:self.max_n_predict]
                labels = labels[:self.max_n_predict]
                scores = scores[:self.max_n_predict]
            # now, implement drawing
            bboxes = bboxes.cpu().numpy()
            labels = labels.cpu().numpy()
            scores = scores.cpu().numpy()
            if original_size:
                # Map boxes back from the resized image to the original size.
                bboxes = adjust_bbox(scale, bboxes, detect=True)
            if not return_img:
                return bboxes, labels, scores
            else:
                if with_scores:
                    scores_ = scores
                else:
                    scores_ = []
                predicted_img = to_img(self.conf, img[0])
                if original_size:
                    predicted_img = predicted_img.resize(orig_size)
                if len(bboxes) != 0 and len(labels) != 0:
                    predicted_img = draw_bbox_class(self.conf, predicted_img, labels, bboxes, self.conf.correct_id_2_class, self.class_2_color, scores = scores_)
                return predicted_img

    def _suppress(self, raw_cls_bboxes, raw_prob, use_softnms):
        """Per-class score thresholding + (soft-)NMS over decoded boxes."""
        bbox = []
        label = []
        prob = []
        for l in range(1, self.conf.class_num + 1):
            cls_bbox_l = raw_cls_bboxes[:, l, :]
            prob_l = raw_prob[:, l]
            mask = prob_l > self.score_thresh
            if not mask.any():
                continue
            cls_bbox_l = cls_bbox_l[mask]
            prob_l = prob_l[mask]
            if use_softnms:
                keep, _ = soft_nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1).cpu().numpy(),
                                   Nt = self.conf.softnms_Nt,
                                   method = self.conf.softnms_method,
                                   sigma = self.conf.softnms_sigma,
                                   min_score = self.conf.softnms_min_score)
                keep = keep.tolist()
            else:
                # prob_l, order = torch.sort(prob_l, descending=True)
                # cls_bbox_l = cls_bbox_l[order]
                keep = nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1), self.nms_thresh).tolist()
            bbox.append(cls_bbox_l[keep])
            # The labels are in [0, 79].
            label.append((l - 1) * torch.ones((len(keep),), dtype = torch.long))
            prob.append(prob_l[keep])
        if len(bbox) == 0:
            print("looks like there is no prediction have a prob larger than thresh")
            return [], [], []
        bbox = torch.cat(bbox)
        label = torch.cat(label)
        prob = torch.cat(prob)
        return bbox, label, prob

    def board_scalars(self, key, loss_total, rpn_loc_loss, rpn_cls_loss, ohem_roi_loc_loss, ohem_roi_cls_loss, total_roi_loc_loss, total_roi_cls_loss):
        """Write one set of loss scalars to TensorBoard under prefix ``key``."""
        self.writer.add_scalar('{}_loss_total'.format(key), loss_total, self.step)
        self.writer.add_scalar('{}_rpn_loc_loss'.format(key), rpn_loc_loss, self.step)
        self.writer.add_scalar('{}_rpn_cls_loss'.format(key), rpn_cls_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_loc_loss'.format(key), ohem_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_cls_loss'.format(key), ohem_roi_cls_loss, self.step)
        self.writer.add_scalar('{}_total_roi_loc_loss'.format(key), total_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_total_roi_cls_loss'.format(key), total_roi_cls_loss, self.step)

    def use_preset(self, preset):
        """Use the given preset during prediction.

        This method changes values of :obj:`self.nms_thresh`,
        :obj:`self.score_thresh` and :obj:`self.max_n_predict`.  These values
        are a threshold value used for non maximum suppression and a threshold
        value to discard low confidence proposals in :meth:`predict`,
        respectively.

        If the attributes need to be changed to something other than the
        values provided in the presets, please modify them by directly
        accessing the public attributes.

        Args:
            preset ({'visualize', 'evaluate', 'debug'}): A string to determine
                the preset to use.
        """
        if preset == 'visualize':
            self.nms_thresh = 0.5
            self.score_thresh = 0.25
            self.max_n_predict = 40
        elif preset == 'evaluate':
            self.nms_thresh = 0.5
            self.score_thresh = 0.0
            self.max_n_predict = 100
            # """
            # We finally replace origi-nal 0.3 threshold with 0.5 for Non-maximum Suppression
            # (NMS). It improves 0.6 points of mmAP by improving the
            # recall rate especially for the crowd cases.
            # """
        elif preset == 'debug':
            self.nms_thresh = 0.5
            self.score_thresh = 0.0
            self.max_n_predict = 10
        else:
            raise ValueError('preset must be visualize or evaluate')

    def fit(self, epochs=20, resume=False, from_save_folder=False):
        """Main training loop with warmup, periodic eval/logging/checkpointing."""
        if resume:
            self.resume_training_load(from_save_folder)
        self.set_training()
        running_loss = 0.
        running_rpn_loc_loss = 0.
        running_rpn_cls_loss = 0.
        running_ohem_roi_loc_loss = 0.
        running_ohem_roi_cls_loss = 0.
        running_total_roi_loc_loss = 0.
        running_total_roi_cls_loss = 0.
        map05 = None
        val_loss = None
        # Epoch is derived from the persisted global step, so resume works.
        epoch = self.step // self.train_length
        while epoch <= epochs:
            print('start the training of epoch : {}'.format(epoch))
            self.lr_schedule(epoch)
            # for index in tqdm(np.random.permutation(self.train_length), total = self.train_length):
            for index in tqdm(range(self.train_length), total = self.train_length):
                try:
                    inputs = self.train_dataset[index]
                except:
                    # NOTE(review): stray '}' inside this message string.
                    print('loading index {} from train dataset failed}'.format(index))
                    # print(self.train_dataset.orig_dataset._datasets[0].id_to_prop[self.train_dataset.orig_dataset._datasets[0].ids[index]])
                    continue
                self.optimizer.zero_grad()
                train_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0), inputs.scale, inputs.bboxes, inputs.labels)
                train_outputs.loss_total.backward()
                if epoch == 0:
                    # Linear LR warmup only during the first epoch.
                    if self.step <= self.warm_up_duration:
                        self.lr_warmup()
                self.optimizer.step()
                torch.cuda.empty_cache()
                running_loss += train_outputs.loss_total.item()
                running_rpn_loc_loss += train_outputs.rpn_loc_loss
                running_rpn_cls_loss += train_outputs.rpn_cls_loss
                running_ohem_roi_loc_loss += train_outputs.ohem_roi_loc_loss
                running_ohem_roi_cls_loss += train_outputs.ohem_roi_cls_loss
                running_total_roi_loc_loss += train_outputs.total_roi_loc_loss
                running_total_roi_cls_loss += train_outputs.total_roi_cls_loss
                if self.step != 0:
                    if self.step % self.board_loss_every == 0:
                        # Log running means and reset the accumulators.
                        self.board_scalars('train',
                                           running_loss / self.board_loss_every,
                                           running_rpn_loc_loss / self.board_loss_every,
                                           running_rpn_cls_loss / self.board_loss_every,
                                           running_ohem_roi_loc_loss / self.board_loss_every,
                                           running_ohem_roi_cls_loss / self.board_loss_every,
                                           running_total_roi_loc_loss / self.board_loss_every,
                                           running_total_roi_cls_loss / self.board_loss_every)
                        running_loss = 0.
                        running_rpn_loc_loss = 0.
                        running_rpn_cls_loss = 0.
                        running_ohem_roi_loc_loss = 0.
                        running_ohem_roi_cls_loss = 0.
                        running_total_roi_loc_loss = 0.
                        running_total_roi_cls_loss = 0.
                    if self.step % self.evaluate_every == 0:
                        val_loss, val_rpn_loc_loss, \
                        val_rpn_cls_loss, \
                        ohem_val_roi_loc_loss, \
                        ohem_val_roi_cls_loss, \
                        total_val_roi_loc_loss, \
                        total_val_roi_cls_loss = self.evaluate(num = self.conf.eva_num_during_training)
                        self.set_training()
                        self.board_scalars('val', val_loss, val_rpn_loc_loss, val_rpn_cls_loss, ohem_val_roi_loc_loss, ohem_val_roi_cls_loss, total_val_roi_loc_loss, total_val_roi_cls_loss)
                    if self.step % self.eva_on_coco_every == 0:
                        try:
                            cocoEval = self.eva_on_coco(limit = self.conf.coco_eva_num_during_training)
                            self.set_training()
                            map05 = cocoEval[1]
                            mmap = cocoEval[0]
                        except:
                            print('eval on coco failed')
                            map05 = -1
                            mmap = -1
                        self.writer.add_scalar('0.5IoU MAP', map05, self.step)
                        self.writer.add_scalar('0.5::0.9 - MMAP', mmap, self.step)
                    if self.step % self.board_pred_image_every == 0:
                        # Render predictions on the first 20 val images.
                        for i in range(20):
                            img, _, _, _ , _= self.val_dataset.orig_dataset[i]
                            img = Image.fromarray(img.astype('uint8').transpose(1,2,0))
                            predicted_img = self.predict_on_img(img, preset='visualize', return_img=True, with_scores=True, original_size=True)
                            # if type(predicted_img) == tuple:
                            #     self.writer.add_image('pred_image_{}'.format(i), trans.ToTensor()(img), global_step=self.step)
                            # else: ## should be deleted after test
                            self.writer.add_image('pred_image_{}'.format(i), trans.ToTensor()(predicted_img), global_step=self.step)
                        self.set_training()
                    if self.step % self.save_every == 0:
                        try:
                            self.save_state(val_loss, map05)
                        except:
                            print('save state failed')
                    self.step += 1
                    continue
                self.step += 1
            epoch = self.step // self.train_length
        # Final checkpoint goes to the 'save' folder instead of 'model'.
        try:
            self.save_state(val_loss, map05, to_save_folder=True)
        except:
            print('save state failed')

    def eva_on_coco(self, limit = 1000, preset = 'evaluate', use_softnms = False):
        """Run the COCO evaluation helper on up to ``limit`` validation images."""
        self.eval()
        return eva_coco(self.val_dataset.orig_dataset, lambda x : self.eval_predict(x, preset, use_softnms), limit, preset)

    def evaluate(self, num=None):
        """Compute mean validation losses over ``num`` images (all if None)."""
        self.eval()
        running_loss = 0.
        running_rpn_loc_loss = 0.
        running_rpn_cls_loss = 0.
        running_ohem_roi_loc_loss = 0.
        running_ohem_roi_cls_loss = 0.
        running_total_roi_loc_loss = 0.
        running_total_roi_cls_loss = 0.
        if num == None:
            total_num = self.val_length
        else:
            total_num = num
        with torch.no_grad():
            for index in tqdm(range(total_num)):
                inputs = self.val_dataset[index]
                if inputs.bboxes == []:
                    # Skip images without annotations.
                    continue
                val_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0), inputs.scale, inputs.bboxes, inputs.labels, force_eval = True)
                running_loss += val_outputs.loss_total.item()
                running_rpn_loc_loss += val_outputs.rpn_loc_loss
                running_rpn_cls_loss += val_outputs.rpn_cls_loss
                running_ohem_roi_loc_loss += val_outputs.ohem_roi_loc_loss
                running_ohem_roi_cls_loss += val_outputs.ohem_roi_cls_loss
                running_total_roi_loc_loss += val_outputs.total_roi_loc_loss
                running_total_roi_cls_loss += val_outputs.total_roi_cls_loss
        return running_loss / total_num, \
               running_rpn_loc_loss / total_num, \
               running_rpn_cls_loss / total_num, \
               running_ohem_roi_loc_loss / total_num, \
               running_ohem_roi_cls_loss / total_num,\
               running_total_roi_loc_loss / total_num, \
               running_total_roi_cls_loss / total_num

    def save_state(self, val_loss, map05, to_save_folder=False, model_only=False):
        """Save model (and optionally optimizer) state, tagged with metrics/step."""
        if to_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'
        time = get_time()
        torch.save(
            self.state_dict(), save_path /
            ('model_{}_val_loss:{}_map05:{}_step:{}.pth'.format(time, val_loss, map05, self.step)))
        if not model_only:
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_val_loss:{}_map05:{}_step:{}.pth'.format(time, val_loss, map05, self.step)))

    def load_state(self, fixed_str, from_save_folder=False, model_only=False):
        """Load model (and optionally optimizer) state saved by save_state."""
        if from_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'
        self.load_state_dict(torch.load(save_path/'model_{}'.format(fixed_str)))
        print('load model_{}'.format(fixed_str))
        if not model_only:
            self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))
            print('load optimizer_{}'.format(fixed_str))

    def resume_training_load(self, from_save_folder=False):
        """Find the newest matching model/optimizer checkpoint pair and load it.

        Scans the folder newest-first; a pair matches when the model and
        optimizer files share the same suffix.  Also restores ``self.step``
        from the filename's trailing step number.
        """
        if from_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'
        sorted_files = sorted([*save_path.iterdir()],
                              key=lambda x: os.path.getmtime(x),
                              reverse=True)
        seeking_flag = True
        index = 0
        while seeking_flag:
            if index > len(sorted_files) - 2:
                break
            file_a = sorted_files[index]
            file_b = sorted_files[index + 1]
            if file_a.name.startswith('model'):
                fix_str = file_a.name[6:]
                self.step = int(fix_str.split(':')[-1].split('.')[0]) + 1
                if file_b.name == ''.join(['optimizer', '_', fix_str]):
                    self.load_state(fix_str, from_save_folder)
                    return
                else:
                    index += 1
                    continue
            elif file_a.name.startswith('optimizer'):
                fix_str = file_a.name[10:]
                self.step = int(fix_str.split(':')[-1].split('.')[0]) + 1
                if file_b.name == ''.join(['model', '_', fix_str]):
                    self.load_state(fix_str, from_save_folder)
                    return
                else:
                    index += 1
                    continue
            else:
                index += 1
                continue
        print('no available files founded')
        return
def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations, val_iterations, mixed_precision, lr, warmup, milestones, gamma, is_master=True, world=1, use_dali=True, verbose=True, metrics_url=None, logdir=None):
    """Train the model on the given dataset.

    Iteration-based (not epoch-based) training loop with apex AMP mixed
    precision, optional multi-GPU DistributedDataParallel, optional DALI data
    loading, periodic checkpointing/telemetry and periodic validation.

    Args:
        model: detection model; must expose ``.stride`` and ``.save(state)``.
        state: dict used to resume/persist 'iteration', 'optimizer', 'scheduler'.
        path / annotations: training images dir and annotations file.
        val_path / val_annotations: validation data, used every
            ``val_iterations`` iterations (and at the end) when provided.
        iterations: total optimizer steps to run.
        mixed_precision: use amp opt_level 'O2' (else 'O0' full precision).
        warmup / milestones / gamma: LR schedule — linear warmup then step decay.
        is_master / world: distributed role and world size.
        metrics_url / logdir: optional metrics endpoint and TensorBoard dir.
    """
    # Prepare model: keep a handle to the original (for .save), freeze
    # batch-norm stats, move to GPU if available.
    nn_model = model
    stride = model.stride
    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)

    # apex AMP wraps model+optimizer; loss_scale=128.0 is a fixed (not
    # dynamic) scale for fp16 gradients.
    model, optimizer = amp.initialize(
        model, optimizer,
        opt_level='O2' if mixed_precision else 'O0',
        keep_batchnorm_fp32=True,
        loss_scale=128.0,
        verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    # Resume optimizer state if present in the checkpoint dict.
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # LR multiplier: linear warmup from 0.1 to 1.0 over `warmup` iters,
        # then gamma^k step decay where k = number of passed milestones.
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])
    # With AMP, the real optimizer is wrapped — schedule the inner one.
    scheduler = LambdaLR(optimizer.optimizer if mixed_precision else optimizer, schedule)

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, stride,
        world, annotations, training=True)
    if verbose:
        print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'gpu' if world == 1 else 'gpus'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer (import is local so tensorboardX is only a
    # dependency when logging is requested).
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            # NOTE(review): calling scheduler.step(iteration) before
            # optimizer.step() is the pre-PyTorch-1.1 idiom; kept as-is since
            # LambdaLR recomputes the LR purely from the passed iteration.
            scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')
            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data  # free the batch tensor before backward to save memory
            profiler.stop('fw')

            # Backward pass — amp.scale_loss scales for fp16, unscales on exit.
            profiler.start('bw')
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Reduce all losses across workers (average over `world`).
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            # Abort on NaN/inf loss rather than training further on garbage.
            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            # Report + checkpoint roughly once a minute of train time, and at
            # the very last iteration.
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)

                del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights — ignore_sigint keeps Ctrl-C from
                # corrupting the checkpoint mid-write.
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # Periodic validation; infer() switches to eval internally, so
            # restore training mode afterwards.
            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                infer(model, val_path, None, resize, max_size, batch_size,
                      annotations=val_annotations,
                      mixed_precision=mixed_precision,
                      is_master=is_master, world=world, use_dali=use_dali,
                      verbose=False)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
def main():
    """Entry point: build the model, optionally resume, then train/evaluate.

    Reads configuration from the module-level argparse ``parser`` and mutates
    the module globals ``args`` and ``best_prec1``. Checkpoints are written
    each epoch via ``save_checkpoint``; the best prec@1 model is flagged.
    """
    global args, best_prec1
    args = parser.parse_args()

    model = create_model(args.arch, args.pretrained, args.finetune,
                         num_classes=args.num_classes)

    # Define loss function (criterion) and optimizer.
    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(
        # Only finetunable params — frozen layers have requires_grad=False.
        filter(lambda p: p.requires_grad, model.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            load_model_from_checkpoint(args, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True  # fixed input sizes -> let cuDNN pick fast kernels

    # Data loading. Initialize both loaders to None so a missing directory
    # fails fast below with a clear error instead of a later NameError
    # (the original code left the names unbound when a path did not exist).
    train_loader = None
    test_loader = None
    train_path = os.path.join(args.data, 'train')
    test_path = os.path.join(args.data, 'test')
    if os.path.exists(train_path):
        train_loader = read_fer2013_data(train_path,
                                         dataset_type='train',
                                         batch_size=args.batch_size,
                                         num_workers=args.workers)
    if os.path.exists(test_path):
        test_loader = read_fer2013_data(test_path,
                                        dataset_type='test',
                                        batch_size=args.batch_size,
                                        num_workers=args.workers)
    if test_loader is None:
        raise FileNotFoundError("test data not found at '{}'".format(test_path))

    if args.evaluate:
        test(test_loader, model, criterion, args.print_freq)
        return

    if train_loader is None:
        raise FileNotFoundError("train data not found at '{}'".format(train_path))

    summary_writer = SummaryWriter()
    try:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(args.lr, optimizer, epoch, args.lr_decay,
                                 args.lr_decay_freq)

            # Train for one epoch, then evaluate on the test set.
            train(train_loader, model, criterion, optimizer, epoch,
                  args.print_freq, summary_writer)
            prec1 = test(test_loader, model, criterion, args.print_freq)

            # Remember best prec@1 and save all checkpoints.
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best)
    finally:
        # Flush TensorBoard events even if training raises.
        summary_writer.close()
def train_model(train_loader, test_loader, device, lr, epochs, output_path, valid_loader=None):
    """Train a fresh CNN, checkpoint it, and plot loss/accuracy curves.

    Args:
        train_loader: training batches consumed by ``loop_dataset``.
        test_loader: test batches evaluated once per epoch.
        device: torch device the model is moved to.
        lr: learning rate for plain SGD.
        epochs: number of training epochs.
        output_path: directory receiving 'checkpoint.pth', 'loss.png'
            and 'accuracy.png'.
        valid_loader: optional extra loader evaluated once after training.
            Default changed from ``False`` to ``None`` (both falsy, so
            callers passing ``False`` are unaffected).
    """
    model = CNN().to(device)
    optimizer = SGD(model.parameters(), lr=lr)

    average_loss_train, average_loss_test = [], []
    accuracy_train, accuracy_test = [], []

    for epoch in range(epochs):
        # Training pass: passing the optimizer makes loop_dataset update weights.
        model.train()
        correct_train, loss_train, _ = loop_dataset(model, train_loader,
                                                    device, optimizer)
        mean_train_loss = np.mean(loss_train)  # compute once, reuse below
        print(
            f'Epoch {epoch} : average train loss - {mean_train_loss}, train accuracy - {correct_train}'
        )
        average_loss_train.append(mean_train_loss)
        accuracy_train.append(correct_train)

        # Evaluation pass: no optimizer, so no weight updates.
        model.eval()
        correct_test, loss_test, _ = loop_dataset(model, test_loader, device)
        mean_test_loss = np.mean(loss_test)
        print(
            f'Epoch {epoch} : average test loss - {mean_test_loss}, test accuracy - {correct_test}'
        )
        average_loss_test.append(mean_test_loss)
        accuracy_test.append(correct_test)

    model.eval()
    # Attach activation hooks to every layer before the optional validation
    # pass (forward_hook is defined elsewhere in this file).
    for layer in model.layers:
        layer.register_forward_hook(forward_hook)

    if valid_loader:
        correct_valid, _, output = loop_dataset(model, valid_loader, device)
        print('\033[99m' + f'Accuracy on VALID test: {correct_valid}' + '\033[0m')

    # Persist an untrained model instance (for reconstruction), the trained
    # weights and the optimizer state.
    checkpoint = {
        'model': CNN(),
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, os.path.join(output_path, 'checkpoint.pth'))

    # Loss curves: green = train, red = test.
    plt.figure()
    plt.plot(range(epochs), average_loss_train, lw=0.3, c='g')
    plt.plot(range(epochs), average_loss_test, lw=0.3, c='r')
    plt.legend(['train loss', 'test_loss'])
    plt.xlabel('#Epoch')
    plt.ylabel('Loss')
    plt.savefig(jpath(output_path, 'loss.png'))

    # Accuracy curves: green = train, red = test.
    plt.figure()
    plt.plot(range(epochs), accuracy_train, lw=0.3, c='g')
    plt.plot(range(epochs), accuracy_test, lw=0.3, c='r')
    plt.legend(['train_acc', 'test_acc'])
    plt.xlabel('#Epoch')
    plt.ylabel('Accuracy')
    plt.savefig(jpath(output_path, 'accuracy.png'))