class TrainModel:
    """Base training harness for the depth-regression generator.

    Subclasses are expected to implement `train_batch`, `get_checkpoint`,
    `mean_errors`, `get_current_errors`, `get_current_visuals`, etc., and to
    maintain `self.netG` / `self.rmse_epoch` used by training and evaluation.
    """

    def name(self):
        """Human-readable model name."""
        return 'Train Model'

    def initialize(self, opt):
        """Set up paths, logging, seeds and device from the parsed options.

        NOTE(review): on `--resume` only the plot data is restored here;
        `start_epoch`/`best_val_error` are expected to be restored by the
        resume path elsewhere — confirm before relying on them.
        """
        self.opt = opt
        # imageSize may be a single value; duplicate it to (H, W)
        self.opt.imageSize = self.opt.imageSize if len(
            self.opt.imageSize) == 2 else self.opt.imageSize * 2
        self.gpu_ids = ''
        self.batchSize = self.opt.batchSize
        self.checkpoints_path = os.path.join(self.opt.checkpoints,
                                             self.opt.name)
        self.scheduler = None
        self.create_save_folders()

        # criterion to evaluate the val split
        self.criterion_eval = MSEScaledError()
        self.mse_scaled_error = MSEScaledError()

        self.opt.print_freq = self.opt.display_freq
        self.visualizer = Visualizer(opt)

        if self.opt.resume and self.opt.display_id > 0:
            self.load_plot_data()
        elif opt.train:
            self.start_epoch = 1
            self.best_val_error = 999.9

        # Logfiles in append mode so resumed runs keep their history.
        self.logfile = open(
            os.path.join(self.checkpoints_path, 'logfile.txt'), 'a')
        if opt.validate:
            self.logfile_val = open(
                os.path.join(self.checkpoints_path, 'logfile_val.txt'), 'a')

        # Fixed seed so every run is reproducible.
        self.random_seed = 123
        random.seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.manual_seed(self.random_seed)

        if opt.cuda:
            self.cuda = torch.device(
                'cuda:0')  # set externally. ToDo: set internally
            torch.cuda.manual_seed(self.random_seed)

        # cudnn auto-tuner finds the fastest convolution algorithms at the
        # cost of extra memory -- kept behind a flag.
        cudnn.benchmark = self.opt.use_cudnn_benchmark
        cudnn.enabled = True

        if not opt.train and not opt.test and not opt.resume:
            raise Exception("You have to set --train or --test")
        # BUGFIX: `torch.cuda.is_available` is a function; it was previously
        # tested without calling it, so both checks below were always truthy.
        if torch.cuda.is_available() and not opt.cuda:
            print(
                "WARNING: You have a CUDA device, so you should run WITHOUT --cpu"
            )
        if not torch.cuda.is_available() and opt.cuda:
            raise Exception("No GPU found, run WITH --cpu")

    def set_input(self, input):
        """Store the current input batch."""
        self.input = input

    def create_network(self):
        """Build the generator network described by the options."""
        netG = networks.define_G(input_nc=self.opt.input_nc,
                                 output_nc=self.opt.output_nc,
                                 ngf=64,
                                 net_architecture=self.opt.net_architecture,
                                 opt=self.opt,
                                 gpu_ids='')
        if self.opt.cuda:
            netG = netG.cuda()
        return netG

    def get_optimizerG(self, network, lr, weight_decay=0.0):
        """Adam optimizer over the network's trainable parameters only."""
        generator_params = filter(lambda p: p.requires_grad,
                                  network.parameters())
        return optim.Adam(generator_params,
                          lr=lr,
                          betas=(self.opt.beta1, 0.999),
                          weight_decay=weight_decay)

    def get_checkpoint(self, epoch):
        """Subclasses return the state to be checkpointed."""
        pass

    def train_batch(self):
        """Each subclass has a different implementation."""
        pass

    def display_gradients_norms(self):
        # placeholder -- gradient-norm reporting not implemented yet
        return 'nothing yet'

    def get_current_errors_display(self):
        pass

    def get_regression_criterion(self):
        # NOTE(review): implicitly returns None for any value other than 'L1'
        if self.opt.regression_loss == 'L1':
            return nn.L1Loss()

    def get_variable(self, tensor, requires_grad=False):
        """Wrap a tensor in a (legacy) Variable, moving it to GPU if enabled."""
        if self.opt.cuda:
            tensor = tensor.cuda()
        return Variable(tensor, requires_grad=requires_grad)

    def restart_variables(self):
        """Reset per-epoch accumulators."""
        self.it = 0
        self.rmse = 0
        self.n_images = 0

    def train(self, data_loader, val_loader=None):
        """Main training loop: iterate epochs and batches, log, validate, save."""
        self.data_loader = data_loader
        self.len_data_loader = len(self.data_loader)
        self.total_iter = 0
        # NOTE(review): range(start, nEpochs) excludes epoch nEpochs itself --
        # confirm this matches the intended epoch count.
        for epoch in range(self.start_epoch, self.opt.nEpochs):
            self.restart_variables()
            self.data_iter = iter(self.data_loader)
            self.pbar = range(self.len_data_loader)
            for self.it in self.pbar:
                if self.opt.optim == 'SGD':
                    self.scheduler.step()
                self.total_iter += self.opt.batchSize
                self.netG.train(True)

                iter_start_time = time.time()
                self.train_batch()
                d_time = (time.time() - iter_start_time) / self.opt.batchSize

                self.print_current_errors(epoch, d_time)
                self.display_current_results(epoch)
                self.evaluate(val_loader, epoch)
                self.save_checkpoint(epoch, is_best=0)

        self.logfile.close()
        if self.opt.validate:
            self.logfile_val.close()

    def get_next_batch(self):
        """Fetch the next (rgb, depth) pair and copy the rgb into self.input."""
        # BUGFIX: Python 3 iterators have no `.next()` method; use next().
        rgb_cpu, depth_cpu = next(self.data_iter)
        self.input.data.resize_(rgb_cpu.size()).copy_(rgb_cpu)

    def apply_valid_pixels_mask(self, *data, value=0.0):
        """Zero out pixels where data[1] <= value.

        Returns ([masked tensors], number of valid pixels).
        """
        mask = (data[1].data > value).to(self.cuda, dtype=torch.float32)
        masked_data = [d * mask for d in data]
        return masked_data, mask.sum()

    def update_learning_rate(self, epoch):
        """Linearly decay G and D learning rates after `niter_decay` epochs."""
        if epoch > self.opt.niter_decay and self.opt.use_cgan:
            [self.opt.d_lr, self.optimD] = self._update_learning_rate(
                self.opt.niter_decay, self.opt.d_lr, self.optimD)
            [self.opt.lr, self.optimG] = self._update_learning_rate(
                self.opt.niter_decay, self.opt.lr, self.optimG)

    def _update_learning_rate(self, niter_decay, old_lr, optim):
        """Apply one linear decay step and write the new lr into the optimizer."""
        lr = old_lr - old_lr / niter_decay
        for param_group in optim.param_groups:
            param_group['lr'] = lr
        return lr, optim

    # CONTROL FUNCTIONS OF THE ARCHITECTURE
    def _get_plot_data_filename(self, phase):
        """Pickle filename for a phase: plot_data.p (train) / plot_data_<phase>.p."""
        return os.path.join(
            self.checkpoints_path,
            'plot_data' + ('' if phase == 'train' else '_' + phase) + '.p')

    def save_static_plot_image(self):
        # BUGFIX: was defined without `self`, so any instance call raised
        # TypeError. Still a stub.
        return None

    def save_interactive_plot_image(self):
        # BUGFIX: was defined without `self` (see save_static_plot_image).
        return None

    def _save_plot_data(self, plot_data, filename):
        # BUGFIX: use a context manager so the file handle is closed.
        with open(filename, 'wb') as f:
            pickle.dump(plot_data, f)

    def save_plot_data(self):
        """Persist the visualizer's train (and optionally val) plot data."""
        self._save_plot_data(self.visualizer.plot_data,
                             self._get_plot_data_filename('train'))
        if self.opt.validate and self.total_iter > self.opt.val_freq:
            self._save_plot_data(self.visualizer.plot_data_val,
                                 self._get_plot_data_filename('val'))

    def _load_plot_data(self, filename):
        """Load pickled plot data, failing loudly if the file is missing."""
        if not os.path.isfile(filename):
            raise Exception(
                'In _load_plot_data file {} doesnt exist.'.format(filename))
        # BUGFIX: use a context manager so the file handle is closed.
        with open(filename, "rb") as f:
            return pickle.load(f)

    def load_plot_data(self):
        """Restore plot data saved by a previous run (used on --resume)."""
        self.visualizer.plot_data = self._load_plot_data(
            self._get_plot_data_filename('train'))
        if self.opt.validate:
            self.visualizer.plot_data_val = self._load_plot_data(
                self._get_plot_data_filename('val'))

    def save_checkpoint(self, epoch, is_best):
        """Save a checkpoint every `save_checkpoint_freq` epochs or when best."""
        if epoch % self.opt.save_checkpoint_freq == 0 or is_best:
            checkpoint = self.get_checkpoint(epoch)
            checkpoint_filename = '{}/{:04}.pth.tar'.format(
                self.checkpoints_path, epoch)
            self._save_checkpoint(checkpoint,
                                  is_best=is_best,
                                  filename=checkpoint_filename)

    def _save_checkpoint(self, state, is_best, filename):
        """Write the state, refresh latest.pth.tar, and copy to best.pth.tar."""
        print("Saving checkpoint...")
        torch.save(state, filename)
        shutil.copyfile(
            filename,
            os.path.join(os.path.dirname(filename), 'latest.pth.tar'))
        if is_best:
            shutil.copyfile(
                filename, os.path.join(self.checkpoints_path, 'best.pth.tar'))

    def create_save_folders(self):
        """Create the checkpoint directory for a training run."""
        if self.opt.train:
            # BUGFIX/portability: use os.makedirs instead of shelling out
            # to `mkdir -p`.
            os.makedirs(self.checkpoints_path, exist_ok=True)

    def print_save_options(self):
        """Print every option and persist them to options.txt."""
        args = dict((arg, getattr(self.opt, arg)) for arg in dir(self.opt)
                    if not arg.startswith('_'))
        # BUGFIX: context manager guarantees the file is closed.
        with open(os.path.join(self.checkpoints_path, 'options.txt'),
                  'w') as options_file:
            print('---Options---')
            for k, v in sorted(args.items()):
                option = '{}: {}'.format(k, v)
                print(option)
                options_file.write(option + '\n')

    def mean_errors(self):
        pass

    def get_current_errors(self):
        pass

    def print_current_errors(self, epoch, d_time):
        """Print the current errors every `print_freq` iterations."""
        if self.total_iter % self.opt.print_freq == 0:
            self.mean_errors()
            errors = self.get_current_errors()
            message = self.visualizer.print_errors(errors, epoch, self.it,
                                                   self.len_data_loader,
                                                   d_time)
            print(message)

    def get_current_visuals(self):
        pass

    def display_current_results(self, epoch):
        """Push errors and images to the visualizer every `display_freq` iters."""
        if self.opt.display_id > 0 and self.total_iter % self.opt.display_freq == 0:
            errors = self.get_current_errors_display()
            self.visualizer.display_errors(
                errors, epoch, float(self.it) / self.len_data_loader)
            visuals = self.get_current_visuals()
            self.visualizer.display_images(visuals, epoch)
            # save printed errors to logfile
            self.visualizer.save_errors_file(self.logfile)

    def evaluate(self, data_loader, epoch):
        """Validate every `val_freq` iterations and keep the best checkpoint."""
        if self.opt.validate and self.total_iter % self.opt.val_freq == 0:
            val_error = self.get_eval_error(data_loader, self.netG,
                                            self.criterion_eval, epoch)
            # NOTE(review): self.rmse_epoch must be maintained by the
            # subclass's training step -- it is never set in this class.
            errors = OrderedDict([('RMSE', self.rmse_epoch),
                                  ('RMSEVal', val_error)])
            self.visualizer.display_errors(errors,
                                           epoch,
                                           float(self.it) /
                                           self.len_data_loader,
                                           phase='val')
            message = self.visualizer.print_errors(errors, epoch, self.it,
                                                   len(data_loader), 0)
            print('[Validation] ' + message)
            self.visualizer.save_errors_file(self.logfile_val)
            self.save_plot_data()

            # save best model
            is_best = self.best_val_error > val_error
            if is_best:
                print("Updating BEST model (epoch {}, iters {})\n".format(
                    epoch, self.total_iter))
                self.best_val_error = val_error
                self.save_checkpoint(epoch, is_best)

    def get_eval_error(self, val_loader, model, criterion, epoch):
        """Run the model over the validation split; return the mean masked RMSE.

        Dropout is deliberately kept active: only train(False) is used, not a
        full eval() pass. NOTE(review): in PyTorch train(False) and eval() are
        equivalent, so dropout IS disabled here -- confirm the intent.
        """
        cumulated_rmse = 0
        batchSize = 1
        input = self.get_variable(torch.FloatTensor(batchSize, 3,
                                                    self.opt.imageSize[0],
                                                    self.opt.imageSize[1]),
                                  requires_grad=False)
        mask = self.get_variable(torch.FloatTensor(batchSize, 1,
                                                   self.opt.imageSize[0],
                                                   self.opt.imageSize[1]),
                                 requires_grad=False)
        target = self.get_variable(
            torch.FloatTensor(batchSize, 1, self.opt.imageSize[0],
                              self.opt.imageSize[1]))
        model.train(False)
        pbar_val = tqdm(val_loader)
        for i, (rgb_cpu, depth_cpu) in enumerate(pbar_val):
            pbar_val.set_description('[Validation]')
            input.data.resize_(rgb_cpu.size()).copy_(rgb_cpu)
            target.data.resize_(depth_cpu.size()).copy_(depth_cpu)

            if self.opt.use_padding:
                from torch.nn import ReflectionPad2d
                self.opt.padding = self.get_padding_image(input)
                input = ReflectionPad2d(self.opt.padding)(input)
                target = ReflectionPad2d(self.opt.padding)(target)

            # get output of the network
            with torch.no_grad():
                outG = model.forward(input)

            # apply mask of valid (non-zero) depth pixels
            nomask_outG = outG.data  # for displaying purposes
            mask_ByteTensor = self.get_mask(target.data)
            mask.data.resize_(mask_ByteTensor.size()).copy_(mask_ByteTensor)
            outG = outG * mask
            target = target * mask

            cumulated_rmse += sqrt(criterion(outG, target, mask,
                                             no_mask=False))
            if i == 1:
                self.visualizer.display_images(OrderedDict([
                    ('input', input.data),
                    ('gt', target.data),
                    ('output', nomask_outG)
                ]),
                                               epoch='val {}'.format(epoch),
                                               phase='val')
        return cumulated_rmse / len(val_loader)

    def get_mask(self, data, value=0.0):
        """Boolean mask of valid (data > value) pixels.

        BUGFIX: previously returned `(target.data > 0.0)`, referencing the
        undefined name `target` (NameError) and ignoring the `value` parameter.
        """
        return (data > value)

    def get_padding(self, dim):
        """Padding needed to reach the next multiple of 32 strictly above `dim`.

        NOTE(review): a dim already multiple of 32 still gets 32 extra pixels
        -- confirm this is intended.
        """
        final_dim = (dim // 32 + 1) * 32
        return final_dim - dim

    def get_padding_image(self, img):
        """Symmetric (left, right, top, bottom) padding for an NCHW tensor."""
        h, w = img.size()[2:]
        w_pad, h_pad = self.get_padding(w), self.get_padding(h)
        pwr = w_pad // 2
        pwl = w_pad - pwr
        phb = h_pad // 2
        phu = h_pad - phb
        return (pwl, pwr, phu, phb)

    def adjust_learning_rate(self, initial_lr, optimizer, epoch):
        """Set LR to the initial LR decayed by 10 every `niter_decay` epochs."""
        lr = initial_lr * (0.1**(epoch // self.opt.niter_decay))
        if epoch % self.opt.niter_decay == 0:
            print("LEARNING RATE DECAY HERE: lr = {}".format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
class MTL_TOY:
    """Multi-task training harness for the synthetic toy regression dataset."""

    def name(self):
        """Human-readable model name."""
        return 'MultiTask for Toy Dataset'

    def initialize(self, opt):
        """Build the toy dataset, network and optimizer from parsed options."""
        print(self.name())

        # fixed seeds for reproducibility
        self.random_seed = 123
        np.random.seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.manual_seed(self.random_seed)

        self.opt = opt
        self.batchSize = opt.batchSize
        self.n_tasks = self.opt.n_tasks
        self.errors = OrderedDict()
        if self.opt.display_id > 0:
            self.visualizer = Visualizer(opt)

        # Per-task noise scales (sigmas) for the toy example.
        if self.n_tasks == 2:
            sigmas = [1.0, float(opt.sigma)]
        else:
            # BUGFIX: previously fell through with `sigmas` undefined, which
            # raised a NameError at the print below. Fail with a clear error.
            raise NotImplementedError(
                'sigmas are only defined for n_tasks == 2')
        print('Training toy example with sigmas={}'.format(sigmas))

        # B is shared across tasks; each epsilon carries task-specific
        # information.
        epsilons = np.random.normal(scale=3.5,
                                    size=(self.n_tasks, 100,
                                          250)).astype(np.float32)
        B = np.random.normal(scale=10, size=(100, 250)).astype(np.float32)

        dataset = ToyDatasetTrainTest(sigmas, epsilons, B)
        self.data_loader = data.DataLoader(dataset,
                                           batch_size=self.opt.batchSize,
                                           num_workers=4,
                                           shuffle=False)

        # Alpha is the GP coefficient lambda in the paper.
        self.alpha = 10.0
        self.netG = self.create_network()
        self.optim = self.get_optim()
        self.regression_loss = nn.L1Loss()
        # BUGFIX: respect --cpu instead of hard-coding cuda:0.
        self.cuda = torch.device('cuda:0' if self.opt.cuda else 'cpu')
        # BUGFIX: get_val_error reads this flag, but its initialization was
        # commented out (AttributeError on the first validation pass).
        self.first_epoch_test = True

    def get_optim(self, lr=2e-2):
        """Build the optimizer selected by opt.optim ('adam'/'sgd'/'adagrad').

        NOTE(review): implicitly returns None for any other value.
        """
        if self.opt.optim == 'adam':
            return torch.optim.Adam(self.netG.parameters(), lr=lr)
        elif self.opt.optim == 'sgd':
            return torch.optim.SGD(self.netG.parameters(), lr=lr)
        elif self.opt.optim == 'adagrad':
            return torch.optim.Adagrad(self.netG.parameters(), lr=lr)

    def create_network(self):
        """Instantiate the toy multi-task network, on GPU if requested."""
        from networks.mtl_toynetwork import ToyNetwork
        netG = ToyNetwork(self.opt.n_tasks, self.opt.mtl_method)
        if self.opt.cuda:
            netG = netG.cuda()
        return netG

    def restart_variables(self):
        """Reset per-epoch accumulators."""
        self.it = 0
        self.n_images = 0
        self.losses_sum = 0
        self.loss_sum = 0
        self.reg_loss_sum = np.zeros(self.n_tasks)
        self.n_its = 0

    def task_normalized_training_error(self, losses):
        """Normalize per-task losses by the losses of the first iteration."""
        return np.divide(losses, self.initial_losses)

    def loss_ratio(self, losses):
        """Record task-normalized losses; capture baselines on the first epoch."""
        with torch.no_grad():
            if self.first_epoch[0]:
                self.initial_losses = self.to_numpy(
                    torch.stack(losses).detach())
            self.losses_sum = self.task_normalized_training_error(
                self.to_numpy(torch.stack(losses)))

    def mean_errors(self):
        """Store the running mean loss and normalized total loss."""
        mean_loss = self.loss_sum / self.n_its
        self.set_current_errors(mean_loss=mean_loss)
        self.set_current_errors(NTotalLoss=np.array(self.losses_sum).sum() /
                                self.n_tasks)

    def train(self):
        """Main loop over epochs/batches for the toy experiment."""
        self.len_data_loader = len(self.data_loader)
        self.total_iter = 0
        for epoch in range(1, self.opt.nEpochs):
            self.epoch = epoch
            # wrapped in a list so helpers can read/flip it by reference
            self.first_epoch = [epoch == 1]
            self.restart_variables()
            self.data_iter = iter(self.data_loader)
            self.pbar = range(self.len_data_loader)
            for self.it in self.pbar:
                self.netG.train(True)
                iter_start_time = time.time()
                self.train_batch()
                self.mean_errors()
                d_time = (time.time() - iter_start_time) / self.opt.batchSize
                self.total_iter += self.opt.batchSize
                self.evaluate(epoch)
                self.print_current_errors(epoch, d_time)
                self.display_current_results(epoch)

    def get_val_error(self):
        """Mean task-normalized L1 error over the validation loader.

        NOTE(review): self.val_loader is not created in initialize (the
        dataset line is commented out there); it must be provided externally.
        """
        model = self.netG.train(False)
        pbar_val = tqdm(self.val_loader)
        norm_losses = 0
        with torch.no_grad():
            for i, (input_cpu, target_cpu) in enumerate(pbar_val):
                input_data = input_cpu.to(self.cuda)
                output = model(input_data)
                task_loss = []
                for i_task in range(self.n_tasks):
                    target = target_cpu[:, i_task, :].to(self.cuda)
                    task_loss.append(
                        self.regression_loss(output[:, i_task, :], target))
                np_task_losses = self.to_numpy(torch.stack(task_loss))
                # capture the first-pass losses as the normalization baseline
                if self.first_epoch_test:
                    self.val_initial_error = np_task_losses
                    self.first_epoch_test = False
                norm_losses += np.divide(np_task_losses,
                                         self.val_initial_error)
        return norm_losses.mean()

    def evaluate(self, epoch):
        """Run validation every `val_freq` epochs (when --validate is set)."""
        if self.opt.validate and (epoch - 1) % self.opt.val_freq == 0:
            val_error = self.get_val_error()
            self.val_error = OrderedDict([('Val error', val_error)])
            print('Eval error is: {}'.format(val_error))

    def to_numpy(self, data):
        """Detach a tensor to a CPU numpy array."""
        return data.data.cpu().numpy()

    def train_batch(self):
        """Each subclass has a different implementation."""
        pass

    def set_current_errors(self, **k_dict_elements):
        """Merge keyword errors into the ordered error dict."""
        for key, value in k_dict_elements.items():
            self.errors.update([(key, value)])

    def set_current_errors_string(self, key, value):
        """Set a single error entry by (possibly computed) key."""
        self.errors.update([(key, value)])

    def get_current_errors(self):
        return self.errors

    def get_current_errors_display(self):
        return self.errors

    def display_current_results(self, epoch):
        """Push errors to the visualizer on the last batch of display epochs."""
        if self.opt.display_id > 0 and (
                epoch - 1) % self.opt.display_freq == 0 and self.it == (
                    len(self.data_loader) - 1):
            errors = self.get_current_errors_display()
            self.visualizer.display_errors(
                errors, epoch, float(self.it) / self.len_data_loader)
            if self.opt.validate and (epoch - 1) % self.opt.val_freq == 0:
                self.visualizer.display_errors(self.val_error,
                                               epoch,
                                               float(self.it) /
                                               self.len_data_loader,
                                               phase='val')

    def print_current_errors(self, epoch, d_time):
        """Print errors on the last batch of display epochs."""
        if self.opt.display_id > 0 and (
                epoch - 1) % self.opt.display_freq == 0 and self.it == (
                    len(self.data_loader) - 1):
            errors = self.get_current_errors()
            message = self.visualizer.print_errors(errors, epoch, self.it,
                                                   self.len_data_loader,
                                                   d_time)
            print(message)