import os

import numpy as np
from visdom import Visdom


class VisdomLinePlotter(object):
    def __init__(self, env_name='main', logging_path=None):
        self.viz = Visdom(log_to_filename=logging_path)
        # Replay a previous session's log so earlier curves reappear.
        # (Guard added: os.path.isfile(None) raises TypeError.)
        if logging_path is not None and os.path.isfile(logging_path):
            self.viz.replay_log(logging_path)
        self.env = env_name
        self.postfix = ''
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y):
        title_name = '_'.join([title_name, self.postfix])
        if title_name not in self.plots:
            # visdom needs data to create a line window, so the first call
            # duplicates the initial point.
            self.plots[title_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=self.env,
                opts=dict(
                    legend=[split_name],
                    title=title_name,
                    xlabel='iterations',
                    ylabel=var_name,
                ))
        else:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[title_name], name=split_name,
                          update='append')

    def set_cv(self, cv):
        # Suffix subsequent window titles with the cross-validation fold.
        self.postfix = '_'.join(['cv', str(cv)])
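# Usage sketch for VisdomLinePlotter (assumptions: a visdom server is running
# on the default localhost:8097, and 'plots.log' is a hypothetical log path).
# Replaying the log on construction restores curves from a previous run.
plotter = VisdomLinePlotter(env_name='experiment', logging_path='plots.log')
plotter.set_cv(0)  # window titles become e.g. 'loss_curve_cv_0'
for step in range(100):
    loss = 1.0 / (step + 1)  # dummy value standing in for a real metric
    plotter.plot('loss', 'train', 'loss_curve', x=step, y=loss)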
import json
import logging

import torch
from visdom import Visdom

logger = logging.getLogger(__name__)


class VisdomLogger:
    def __init__(self, host='127.0.0.1', port=8097, env='main', log_path=None):
        logger.info(f"using visdom on http://{host}:{port} env={env}")
        self.env = env
        self.viz = Visdom(server=f"http://{host}", port=port, env=env,
                          log_to_filename=log_path)
        self.windows = dict()
        # If a previous log exists, replay it and resume the line names so
        # that new points continue from the last run.
        # (Guard added: log_path.exists() raises if log_path is None.)
        if log_path is not None and log_path.exists():
            self.viz.replay_log(log_path)
            wins = json.loads(self.viz.get_window_data(win=None, env=env))
            for k, v in wins.items():
                names = [int(x['name']) for x in v['content']['data']]
                name = str(max(names) + 1)
                self.windows[v['title']] = {'win': v['id'], 'name': name}

    def add_plot(self, title, **kwargs):
        if title not in self.windows:
            self.windows[title] = {
                'win': None,
                'name': '1',
            }
        self.windows[title]['opts'] = {
            'title': title,
        }
        self.windows[title]['opts'].update(kwargs)

    def add_point(self, title, x, y):
        X, Y = torch.FloatTensor([x]), torch.FloatTensor([y])
        if title not in self.windows:
            self.add_plot(title)
        if self.windows[title]['win'] is None:
            w = self.viz.line(Y=Y, X=X, opts=self.windows[title]['opts'],
                              name=self.windows[title]['name'])
            self.windows[title]['win'] = w
        else:
            self.viz.line(Y=Y, X=X, update='append',
                          win=self.windows[title]['win'],
                          name=self.windows[title]['name'])
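# Usage sketch for this VisdomLogger (assumption: log_path must be a
# pathlib.Path, since __init__ calls log_path.exists(); 'vis.log' is a
# hypothetical file name).
from pathlib import Path

vl = VisdomLogger(host='127.0.0.1', port=8097, env='main',
                  log_path=Path('vis.log'))
vl.add_plot('loss', xlabel='iteration', ylabel='loss')
for i in range(10):
    vl.add_point('loss', x=i, y=1.0 / (i + 1))  # dummy values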
import json
import logging
import time

import torch
from visdom import Visdom

logger = logging.getLogger(__name__)


class VisdomLogger:
    def __init__(self, host='127.0.0.1', port=8097, env='main', log_path=None,
                 rank=None):
        logger.debug(f"using visdom on http://{host}:{port} env={env}")
        self.env = env
        self.rank = rank
        self.viz = Visdom(server=f"http://{host}", port=port, env=env,
                          log_to_filename=log_path)
        self.windows = dict()
        # Only rank 0 (or a single, rank-less process) replays a previous log.
        if log_path is not None and log_path.exists() and (rank is None or rank == 0):
            self.viz.replay_log(log_path)

    def _get_win(self, title, type):
        win_data = json.loads(self.viz.get_window_data(win=None, env=self.env))
        wins = [(w, v) for w, v in win_data.items()
                if v['title'] == title and v['type'] == type]
        if wins:
            handle, value = sorted(wins, key=lambda x: x[0])[0]
            return handle, value['content']
        else:
            return None, None

    def _get_rank0_win(self, title, type):
        if self.rank is not None and self.rank > 0:
            # Poll until the rank-0 process has created the window.
            for _ in range(10):
                handle, content = self._get_win(title, type)
                if handle is not None:
                    return handle, content
                time.sleep(0.5)
            else:
                logger.error("couldn't get a proper window handle from the visdom server")
                raise RuntimeError
        else:
            return self._get_win(title, type)

    def _new_window(self, cmd, title, **cmd_args):
        if cmd == self.viz.images:
            types = ("image", None)
        elif cmd == self.viz.scatter or cmd == self.viz.line:
            types = ("plot", "scatter")
        elif cmd == self.viz.heatmap:
            types = ("plot", "heatmap")
        else:
            types = ("plot", None)

        handle, content = self._get_rank0_win(title, types[0])

        if handle is None:
            if "opts" in cmd_args:
                cmd_args['opts'].update({"title": title})
            else:
                cmd_args['opts'] = {"title": title}
            if types == ("plot", "scatter"):
                # Encode the rank in the trace name so each process appends
                # to its own line in the shared window.
                name = f"1_{self.rank}" if self.rank is not None else "1"
                handle = cmd(name=name, **cmd_args)
            else:
                name = None
                handle = cmd(**cmd_args)
        else:
            if types == ("plot", "scatter"):
                name = max([int(x['name'].partition('_')[0])
                            for x in content['data']])
                name = (f"{name + 1}_{self.rank}" if self.rank is not None
                        else f"{name + 1}")
                cmd(win=handle, name=name, update="append", **cmd_args)
            else:
                name = None
                handle = cmd(win=handle, **cmd_args)

        self.windows[title] = {
            'handle': handle,
            'name': name,
            'opts': cmd_args["opts"],
        }

    def add_point(self, title, x, y, **kwargs):
        X, Y = torch.FloatTensor([x]), torch.FloatTensor([y])
        if title not in self.windows:
            cmd = self.viz.line
            self._new_window(cmd, title, X=X, Y=Y, opts=kwargs)
        else:
            self.windows[title]['opts'].update(kwargs)
            handle = self.windows[title]['handle']
            name = self.windows[title]['name']
            opts = self.windows[title]['opts']
            self.viz.line(win=handle, update='append', Y=Y, X=X, name=name,
                          opts=opts)

    def plot_heatmap(self, title, tensor, **kwargs):
        if title not in self.windows:
            cmd = self.viz.heatmap
            self._new_window(cmd, title, X=tensor, opts=kwargs)
        else:
            self.windows[title]['opts'].update(kwargs)
            handle = self.windows[title]['handle']
            opts = self.windows[title]['opts']
            self.viz.heatmap(win=handle, X=tensor, opts=opts)

    def plot_images(self, title, tensor, nrow, **kwargs):
        if title not in self.windows:
            cmd = self.viz.images
            self._new_window(cmd, title, tensor=tensor, nrow=nrow, opts=kwargs)
        else:
            self.windows[title]['opts'].update(kwargs)
            handle = self.windows[title]['handle']
            opts = self.windows[title]['opts']
            self.viz.images(win=handle, tensor=tensor, nrow=nrow, opts=opts)
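# Usage sketch for the rank-aware VisdomLogger (assumptions: one process per
# rank, with rank 0 started first so it creates the windows; the tensors are
# dummies). Each rank appends to its own named trace ('1_0', '1_1', ...) in
# a window shared across processes.
vl = VisdomLogger(env='distributed_run', log_path=None, rank=0)
vl.add_point('loss', x=0, y=0.9, xlabel='step', ylabel='loss')
vl.plot_heatmap('attention', torch.rand(8, 8))
vl.plot_images('samples', torch.rand(16, 3, 32, 32), nrow=4)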
import argparse

from visdom import Visdom

parser = argparse.ArgumentParser(description='Visdom Log Writer.')
parser.add_argument('--visdom-url', type=str, required=True,
                    help='visdom URL for graphs, needs http://url')
parser.add_argument('--visdom-port', type=int, required=True,
                    help='visdom port for graphs')
parser.add_argument('--log-file', type=str, required=True,
                    help='the visdom log file to replay')
args = parser.parse_args()

if __name__ == "__main__":
    visdom = Visdom(server=args.visdom_url, port=args.visdom_port,
                    use_incoming_socket=False, raise_exceptions=False)
    visdom.replay_log(args.log_file)
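# Example invocation (assumption: the script above is saved as replay_log.py
# and a visdom server is already listening on the given URL/port):
#
#   python replay_log.py --visdom-url http://localhost --visdom-port 8097 \
#       --log-file logs/visdom.log
#
# replay_log re-sends every event in the file to the server, which rebuilds
# the windows exactly as they were when the log was written.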
import json
import os
import warnings

import numpy as np
from visdom import Visdom


class Plot(object):
    def __init__(self, title="", env_name="", config=None, port=8080):
        self.env_name = env_name if env_name else title
        self.viz = Visdom(port=port, env=self.env_name)
        self.windows = {}
        self.title = title
        self.config = config

    def register_plot(self, name, xlabel, ylabel, plot_type="line", ymax=None):
        self.windows[name] = {"xlabel": xlabel, "ylabel": ylabel,
                              "title": name, "plot_type": plot_type}
        self.windows[name]["opts"] = dict(title=name, markersize=5,
                                          xlabel=xlabel, ylabel=ylabel)
        if ymax is not None:
            self.windows[name]["opts"]["layoutopts"] = dict(
                plotly=dict(yaxis=dict(range=[0, ymax])))

    def update_plot(self, plot_name, x, y, **kwargs):
        # Fall back to a default plot if this name was never registered.
        # (Fixed: catch KeyError instead of a bare except, and give the
        # fallback the "opts" entry that plot creation requires.)
        try:
            plot_d = self.windows[plot_name]
        except KeyError:
            warnings.warn("Plot not found, creating new plot")
            plot_d = {"xlabel": "X", "ylabel": "Y", "plot_type": "scatter",
                      "opts": dict(title=plot_name, markersize=5,
                                   xlabel="X", ylabel="Y")}

        plotter = self.viz.scatter if plot_d["plot_type"] == "scatter" else self.viz.line

        x = np.asarray(x)
        if len(x) < len(y):
            # Pad missing x coordinates by interpolating towards zero.
            warnings.warn("X coords not found, interpolating")
            if x[0] == 0 and len(x) > 1:
                x[0] = x[1] - .001
            additional_x = np.linspace(0, x[0], len(y) - len(x))
            x = np.r_[additional_x, np.asarray(x)]

        data = ({"X": x, "Y": np.asarray(y)} if plot_d["plot_type"] == "line"
                else {"X": np.asarray([x, y])})

        if "plot" in plot_d.keys():
            # Update the existing plot.
            plotter(**data, win=plot_d["plot"], update="append")
        else:
            # Create a new plot and remember its window handle.
            # (Fixed: was self.windows["name"] = plot_d.)
            win = plotter(**data, opts=plot_d["opts"], **kwargs)
            plot_d["plot"] = win
            self.windows[plot_name] = plot_d

    # LOADING
    def load_log(self, path):
        self.viz.replay_log(path)

    def load_all_env(self, root, keyword="visdom"):
        for d, ss, fs in os.walk(root):
            for f in fs:
                full_env = os.path.join(d, f)
                # Skip "BSF" (best-so-far) graphs; load only complete graphs.
                if (full_env[-5:] == ".json" and keyword in full_env
                        and f != "losses.json" and "BSF_" not in full_env):
                    print("Loading {}".format(full_env))
                    # replay_log loads the environment into viz
                    self.viz.replay_log(full_env)

    def save_env(self, file_path=None, current_env=None, new_env=None):
        if file_path is None:
            file_path = os.path.join(self.config["results_dir"], "visdom.json")
        if current_env is None:
            current_env = self.env_name
        new_env = current_env if new_env is None else new_env

        # Get the current env's window data and write it as replayable events.
        data = json.loads(self.viz.get_window_data())
        if len(data) == 0:
            print("NOTHING HAS BEEN SAVED: NOTHING IN THIS VISDOM ENV - DOES IT EXIST ?")
            return

        with open(file_path, 'w+') as file:
            for datapoint in data.values():
                output = {
                    'win': datapoint['id'],
                    'eid': new_env,
                    'opts': {}
                }
                if datapoint['type'] != "plot":
                    output['data'] = [{'content': datapoint['content'],
                                       'type': datapoint['type']}]
                    if datapoint['height'] is not None:
                        output['opts']['height'] = datapoint['height']
                    if datapoint['width'] is not None:
                        output['opts']['width'] = datapoint['width']
                else:
                    output['data'] = datapoint['content']["data"]
                    output['layout'] = datapoint['content']["layout"]
                to_write = json.dumps(["events", output])
                file.write(to_write + '\n')
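# Usage sketch for Plot (assumptions: a visdom server on port 8080 as in the
# constructor default, and a config dict providing 'results_dir' for save_env;
# './results' is a hypothetical directory).
plot = Plot(title='training', config={'results_dir': './results'}, port=8080)
plot.register_plot('loss', xlabel='epoch', ylabel='loss', plot_type='line', ymax=2.0)
for epoch in range(5):
    # x and y are passed as sequences; dummy values stand in for real metrics
    plot.update_plot('loss', x=[epoch], y=[1.0 / (epoch + 1)])
plot.save_env()  # writes ./results/visdom.json in replay_log event format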
import argparse

from visdom import Visdom

parser = argparse.ArgumentParser()
parser.add_argument('--log_file', type=str, default='', help='visdom log file to replay')
opt = parser.parse_args()

viz = Visdom(port=8097)
viz.replay_log(opt.log_file)
import os
import time

import numpy as np
import torch
import torch.nn as nn
from visdom import Visdom

import datasets
import models
import utils


class Model:
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device("cuda" if opt.ngpu else "cpu")
        self.model, self.classifier = models.get_model(opt.net_type,
                                                       opt.classifier_type,
                                                       opt.pretrained,
                                                       int(opt.nclasses))
        self.model = self.model.to(self.device)
        self.classifier = self.classifier.to(self.device)
        if opt.ngpu > 1:
            self.model = nn.DataParallel(self.model)

        self.loss = models.init_loss(opt.loss_type)
        self.loss = self.loss.to(self.device)

        self.optimizer = utils.get_optimizer(self.model, self.opt)
        self.lr_scheduler = utils.get_lr_scheduler(self.opt, self.optimizer)
        self.alpha_scheduler = utils.get_margin_alpha_scheduler(self.opt)

        self.train_loader = datasets.generate_loader(opt, 'train')
        self.test_loader = datasets.generate_loader(opt, 'val')

        self.epoch = 0
        self.best_epoch = False
        self.training = False
        self.state = {}

        self.train_loss = utils.AverageMeter()
        self.test_loss = utils.AverageMeter()
        self.batch_time = utils.AverageMeter()
        self.test_metrics = utils.ROCMeter()
        self.best_test_loss = utils.AverageMeter()
        self.best_test_loss.update(np.array([np.inf]))

        self.visdom_log_file = os.path.join(self.opt.out_path, 'log_files',
                                            'visdom.log')
        self.vis = Visdom(port=opt.visdom_port,
                          log_to_filename=self.visdom_log_file,
                          env=opt.exp_name + '_' + str(opt.fold))
        self.vis_loss_opts = {
            'xlabel': 'epoch',
            'ylabel': 'loss',
            'title': 'losses',
            'legend': ['train_loss', 'val_loss']
        }
        self.vis_tpr_opts = {
            'xlabel': 'epoch',
            'ylabel': 'tpr',
            'title': 'val_tpr',
            'legend': ['tpr@fpr10-2', 'tpr@fpr10-3', 'tpr@fpr10-4']
        }
        self.vis_epochloss_opts = {
            'xlabel': 'epoch',
            'ylabel': 'loss',
            'title': 'epoch_losses',
            'legend': ['train_loss', 'val_loss']
        }

    def train(self):
        # Init log file
        if self.opt.resume:
            self.log_msg('resuming...\n')
            # Continue training from checkpoint
            self.load_checkpoint()
        else:
            self.log_msg()

        for epoch in range(self.epoch, self.opt.num_epochs):
            self.epoch = epoch
            # Freeze the backbone for the first freeze_epoch epochs,
            # then unfreeze it exactly once.
            if self.opt.freeze_epoch:
                if epoch < self.opt.freeze_epoch:
                    if self.opt.ngpu > 1:
                        for param in self.model.module.parameters():
                            param.requires_grad = False
                    else:
                        for param in self.model.parameters():
                            param.requires_grad = False
                elif epoch == self.opt.freeze_epoch:
                    if self.opt.ngpu > 1:
                        for param in self.model.module.parameters():
                            param.requires_grad = True
                    else:
                        for param in self.model.parameters():
                            param.requires_grad = True
            self.lr_scheduler.step()
            self.train_epoch()
            self.test_epoch()
            self.log_epoch()
            self.vislog_epoch()
            self.create_state()
            self.save_state()

    def train_epoch(self):
        """Trains the model for one epoch."""
        self.model.train()
        self.classifier.train()
        self.training = True
        torch.set_grad_enabled(self.training)
        self.train_loss.reset()
        self.batch_time.reset()
        time_stamp = time.time()
        self.batch_idx = 0

        for batch_idx, (rgb_data, depth_data, ir_data, target) in enumerate(self.train_loader):
            self.batch_idx = batch_idx
            rgb_data = rgb_data.to(self.device)
            depth_data = depth_data.to(self.device)
            ir_data = ir_data.to(self.device)
            target = target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(rgb_data, depth_data, ir_data)
            if isinstance(self.classifier, nn.Linear):
                output = self.classifier(output)
            else:
                # Margin-based classifiers need the targets (and optionally
                # a scheduled alpha) at training time.
                if self.alpha_scheduler:
                    alpha = self.alpha_scheduler.get_alpha(self.epoch)
                    output = self.classifier(output, target, alpha=alpha)
                else:
                    output = self.classifier(output, target)

            if self.opt.loss_type == 'bce':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)

            loss_tensor.backward()
            self.optimizer.step()

            self.train_loss.update(loss_tensor.item())
            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            self.log_batch(batch_idx)
            self.vislog_batch(batch_idx)

    def test_epoch(self):
        """Calculates loss and metrics on the test set."""
        self.training = False
        torch.set_grad_enabled(self.training)
        self.model.eval()
        self.classifier.eval()
        self.batch_time.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        time_stamp = time.time()

        for batch_idx, (rgb_data, depth_data, ir_data, target) in enumerate(self.test_loader):
            rgb_data = rgb_data.to(self.device)
            depth_data = depth_data.to(self.device)
            ir_data = ir_data.to(self.device)
            target = target.to(self.device)

            output = self.model(rgb_data, depth_data, ir_data)
            output = self.classifier(output)

            if self.opt.loss_type == 'bce':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)
            self.test_loss.update(loss_tensor.item())

            # Map raw logits to probabilities before computing ROC metrics.
            if self.opt.loss_type == 'cce' or self.opt.loss_type == 'focal_loss':
                output = torch.nn.functional.softmax(output, dim=1)
            elif self.opt.loss_type == 'bce':
                output = torch.sigmoid(output)
            self.test_metrics.update(target.cpu().numpy(), output.cpu().numpy())

            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            self.log_batch(batch_idx)
            # self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx == 10):
                print('Debugging done!')
                break

        self.best_epoch = self.test_loss.avg < self.best_test_loss.val
        if self.best_epoch:
            # self.best_test_loss.val is the container for the best loss;
            # n is not used in the calculation.
            self.best_test_loss.update(self.test_loss.avg, n=0)

    def calculate_metrics(self, output, target):
        """Calculates test metrics for the given batch.

        (Fixed: batch_result was never initialized, and the bce branch
        appended an undefined name.)
        """
        batch_result = []
        t = target
        o = output
        if self.opt.loss_type == 'bce':
            binary_accuracy = (t.byte() == (o > 0.5)).float().mean(0).cpu().numpy()
            batch_result.append(binary_accuracy)
        elif self.opt.loss_type == 'cce':
            top1_accuracy = (torch.argmax(o, 1) == t).float().mean().item()
            batch_result.append(top1_accuracy)
        else:
            raise Exception('This loss function is not implemented yet')
        return batch_result

    def log_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            cur_len = len(self.train_loader) if self.training else len(self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss
            output_string = 'Train ' if self.training else 'Test '
            output_string += 'Epoch {}[{:.2f}%]: [{:.2f}({:.3f}) s]\t'.format(
                self.epoch, 100. * batch_idx / cur_len,
                self.batch_time.val, self.batch_time.avg)
            loss_i_string = 'Loss: {:.5f}({:.5f})\t'.format(cur_loss.val, cur_loss.avg)
            output_string += loss_i_string
            if not self.training:
                output_string += '\n'
                metrics_i_string = 'Accuracy: {:.5f}\t'.format(
                    self.test_metrics.get_accuracy())
                output_string += metrics_i_string
            print(output_string)

    def vislog_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            loader_len = len(self.train_loader) if self.training else len(self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss
            loss_type = 'train_loss' if self.training else 'val_loss'
            x_value = self.epoch + batch_idx / loader_len
            y_value = cur_loss.val
            self.vis.line([y_value], [x_value], name=loss_type, win='losses',
                          update='append')
            self.vis.update_window_opts(win='losses', opts=self.vis_loss_opts)

    def log_msg(self, msg=''):
        # An empty message truncates the log file; otherwise append.
        mode = 'a' if msg else 'w'
        with open(os.path.join(self.opt.out_path, 'log_files', 'train_log.txt'), mode) as f:
            f.write(msg)

    def log_epoch(self):
        """Builds and logs the epoch-results string."""
        out_train = 'Train: '
        out_test = 'Test: '
        loss_i_string = 'Loss: {:.5f}\t'.format(self.train_loss.avg)
        out_train += loss_i_string
        loss_i_string = 'Loss: {:.5f}\t'.format(self.test_loss.avg)
        out_test += loss_i_string

        out_test += '\nTest: '
        metrics_i_string = 'TPR@FPR=10-2: {:.4f}\t'.format(self.test_metrics.get_tpr(0.01))
        metrics_i_string += 'TPR@FPR=10-3: {:.4f}\t'.format(self.test_metrics.get_tpr(0.001))
        metrics_i_string += 'TPR@FPR=10-4: {:.4f}\t'.format(self.test_metrics.get_tpr(0.0001))
        out_test += metrics_i_string

        is_best = 'Best ' if self.best_epoch else ''
        out_res = (is_best + 'Epoch {} results:\n'.format(self.epoch)
                   + out_train + '\n' + out_test + '\n')
        print(out_res)
        self.log_msg(out_res)

    def vislog_epoch(self):
        x_value = self.epoch
        self.vis.line([self.train_loss.avg], [x_value], name='train_loss',
                      win='epoch_losses', update='append')
        self.vis.line([self.test_loss.avg], [x_value], name='val_loss',
                      win='epoch_losses', update='append')
        self.vis.update_window_opts(win='epoch_losses', opts=self.vis_epochloss_opts)

        self.vis.line([self.test_metrics.get_tpr(0.01)], [x_value],
                      name='tpr@fpr10-2', win='val_tpr', update='append')
        self.vis.line([self.test_metrics.get_tpr(0.001)], [x_value],
                      name='tpr@fpr10-3', win='val_tpr', update='append')
        self.vis.line([self.test_metrics.get_tpr(0.0001)], [x_value],
                      name='tpr@fpr10-4', win='val_tpr', update='append')
        self.vis.update_window_opts(win='val_tpr', opts=self.vis_tpr_opts)

    def create_state(self):
        self.state = {
            # Params to be saved in checkpoint
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'best_test_loss': self.best_test_loss,
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
        }

    def save_state(self):
        if self.opt.log_checkpoint == 0:
            self.save_checkpoint('checkpoint.pth')
        else:
            if self.epoch % self.opt.log_checkpoint == 0:
                self.save_checkpoint('model_{}.pth'.format(self.epoch))

    def save_checkpoint(self, filename):
        # Save model to out_path/checkpoints/filename.pth
        fin_path = os.path.join(self.opt.out_path, 'checkpoints', filename)
        torch.save(self.state, fin_path)
        if self.best_epoch:
            best_fin_path = os.path.join(self.opt.out_path, 'checkpoints',
                                         'model_best.pth')
            torch.save(self.state, best_fin_path)

    def load_checkpoint(self):
        # Load the current checkpoint if it exists
        fin_path = os.path.join(self.opt.out_path, 'checkpoints', self.opt.resume)
        if os.path.isfile(fin_path):
            print("=> loading checkpoint '{}'".format(fin_path))
            checkpoint = torch.load(fin_path,
                                    map_location=lambda storage, loc: storage)
            self.epoch = checkpoint['epoch'] + 1
            self.best_test_loss = checkpoint['best_test_loss']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.classifier.load_state_dict(checkpoint['classifier_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                self.opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(self.opt.resume))

        if os.path.isfile(self.visdom_log_file):
            self.vis.replay_log(log_filename=self.visdom_log_file)
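# Usage sketch for this trainer (assumption: opt is an argparse.Namespace
# carrying the fields referenced above; parse_args is a hypothetical option
# parser, and the models/datasets/utils modules come from the project):
#
#   opt = parse_args()
#   model = Model(opt)
#   model.train()   # trains, validates, logs to visdom, and checkpoints
#
# Because __init__ opens the Visdom connection with log_to_filename, every
# plotted point is also appended to visdom.log, and load_checkpoint replays
# that file so resumed runs keep their full loss/TPR history.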
import os
import time

import numpy as np
import torch
import torch.nn as nn
from visdom import Visdom

import datasets
import models
import utils


class Model:
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device("cuda" if opt.ngpu else "cpu")
        self.model, self.classifier = models.get_model(opt.net_type,
                                                       opt.loss_type,
                                                       opt.pretrained,
                                                       int(opt.nclasses))
        self.model = self.model.to(self.device)
        self.classifier = self.classifier.to(self.device)
        if opt.ngpu > 1:
            self.model = nn.DataParallel(self.model)

        self.loss = models.init_loss(opt.loss_type)
        self.loss = self.loss.to(self.device)

        self.optimizer = utils.get_optimizer(self.model, self.opt)
        self.lr_scheduler = utils.get_lr_scheduler(self.opt, self.optimizer)

        self.train_loader = datasets.generate_loader(opt, 'train')
        self.test_loader = datasets.generate_loader(opt, 'val')

        self.epoch = 0
        self.best_epoch = False
        self.training = False
        self.state = {}

        self.train_loss = utils.AverageMeter()
        self.test_loss = utils.AverageMeter()
        self.batch_time = utils.AverageMeter()
        # Accuracy-style losses use a running average; otherwise track ROC.
        if self.opt.loss_type in ['cce', 'bce', 'mse', 'arc_margin']:
            self.test_metrics = utils.AverageMeter()
        else:
            self.test_metrics = utils.ROCMeter()
        self.best_test_loss = utils.AverageMeter()
        self.best_test_loss.update(np.array([np.inf]))

        self.visdom_log_file = os.path.join(self.opt.out_path, 'log_files',
                                            'visdom.log')
        self.vis = Visdom(port=opt.visdom_port,
                          log_to_filename=self.visdom_log_file,
                          env=opt.exp_name + '_' + str(opt.fold))
        self.vis_loss_opts = {
            'xlabel': 'epoch',
            'ylabel': 'loss',
            'title': 'losses',
            'legend': ['train_loss', 'val_loss']
        }
        self.vis_epochloss_opts = {
            'xlabel': 'epoch',
            'ylabel': 'loss',
            'title': 'epoch_losses',
            'legend': ['train_loss', 'val_loss']
        }

    def train(self):
        # Init log file
        if self.opt.resume:
            self.log_msg('resuming...\n')
            # Continue training from checkpoint
            self.load_checkpoint()
        else:
            self.log_msg()

        for epoch in range(self.epoch, self.opt.num_epochs):
            self.epoch = epoch
            # Commented-out backbone freezing, kept for reference:
            # if epoch < 0:
            #     for param in self.model.module.body.parameters():
            #         param.requires_grad = False
            # elif epoch == 0:
            #     for param in self.model.module.body.parameters():
            #         param.requires_grad = True
            self.lr_scheduler.step()
            self.train_epoch()
            self.test_epoch()
            self.log_epoch()
            self.vislog_epoch()
            self.create_state()
            self.save_state()

    def train_epoch(self):
        """Trains the model for one epoch."""
        self.model.train()
        self.classifier.train()
        self.training = True
        torch.set_grad_enabled(self.training)
        self.train_loss.reset()
        self.batch_time.reset()
        time_stamp = time.time()
        self.batch_idx = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            self.batch_idx = batch_idx
            data = data.to(self.device)
            target = target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            if isinstance(self.classifier, nn.Linear):
                output = self.classifier(output)
            else:
                # Margin-based classifiers need the targets at training time.
                output = self.classifier(output, target)

            if self.opt.loss_type == 'bce' or self.opt.loss_type == 'mse':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)

            loss_tensor.backward()
            self.optimizer.step()

            self.train_loss.update(loss_tensor.item())
            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            self.log_batch(batch_idx)
            self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx == 10):
                print('Debugging done!')
                break

    def test_epoch(self):
        """Calculates loss and metrics on the test set."""
        self.training = False
        torch.set_grad_enabled(self.training)
        self.model.eval()
        self.classifier.eval()
        self.batch_time.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        time_stamp = time.time()

        for batch_idx, (data, target) in enumerate(self.test_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            output = self.classifier(output)

            if self.opt.loss_type == 'bce' or self.opt.loss_type == 'mse':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)
            self.test_loss.update(loss_tensor.item())

            # Map raw logits to probabilities before computing metrics.
            if self.opt.loss_type == 'cce':
                output = torch.nn.functional.softmax(output, dim=1)
            elif self.opt.loss_type.startswith('arc_margin'):
                output = torch.nn.functional.softmax(output, dim=1)
            elif self.opt.loss_type == 'bce':
                output = torch.sigmoid(output)
            metrics = self.calculate_metrics(output, target)
            self.test_metrics.update(metrics)

            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            self.log_batch(batch_idx)
            # self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx == 10):
                print('Debugging done!')
                break

        self.best_epoch = self.test_loss.avg < self.best_test_loss.val
        if self.best_epoch:
            # self.best_test_loss.val is the container for the best loss;
            # n is not used in the calculation.
            self.best_test_loss.update(self.test_loss.avg, n=0)

    def calculate_metrics(self, output, target):
        """Calculates test metrics for the given batch."""
        batch_result = None
        t = target
        o = output
        if self.opt.loss_type == 'bce':
            binary_accuracy = (t.byte() == (o > 0.5)).float().mean(0).cpu().numpy()
            batch_result = binary_accuracy
        elif self.opt.loss_type == 'mse':
            mean_average_error = torch.abs(t - o.squeeze()).mean(0).cpu().numpy()
            batch_result = mean_average_error
        elif self.opt.loss_type == 'cce' or self.opt.loss_type == 'arc_margin':
            top1_accuracy = (torch.argmax(o, 1) == t).float().mean().item()
            batch_result = top1_accuracy
        else:
            raise Exception('This loss function is not implemented yet')
        return batch_result

    def log_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            cur_len = len(self.train_loader) if self.training else len(self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss
            output_string = 'Train ' if self.training else 'Test '
            output_string += 'Epoch {}[{:.2f}%]: [{:.2f}({:.3f}) s]\t'.format(
                self.epoch, 100. * batch_idx / cur_len,
                self.batch_time.val, self.batch_time.avg)
            loss_i_string = 'Loss: {:.5f}({:.5f})\t'.format(cur_loss.val, cur_loss.avg)
            output_string += loss_i_string
            print(output_string)

    def vislog_batch(self, batch_idx):
        loader_len = len(self.train_loader) if self.training else len(self.test_loader)
        cur_loss = self.train_loss if self.training else self.test_loss
        loss_type = 'train_loss' if self.training else 'val_loss'
        x_value = self.epoch + batch_idx / loader_len
        y_value = cur_loss.val
        self.vis.line([y_value], [x_value], name=loss_type, win='losses',
                      update='append')
        self.vis.update_window_opts(win='losses', opts=self.vis_loss_opts)

    def log_msg(self, msg=''):
        # An empty message truncates the log file; otherwise append.
        mode = 'a' if msg else 'w'
        with open(os.path.join(self.opt.out_path, 'log_files', 'train_log.txt'), mode) as f:
            f.write(msg)

    def log_epoch(self):
        """Builds and logs the epoch-results string."""
        out_train = 'Train: '
        out_test = 'Test: '
        loss_i_string = 'Loss: {:.5f}\t'.format(self.train_loss.avg)
        out_train += loss_i_string
        loss_i_string = 'Loss: {:.5f}\t'.format(self.test_loss.avg)
        out_test += loss_i_string

        out_test += '\nTest: '
        out_test += '{0}\t{1:.4f}\t'.format(self.opt.loss_type, self.test_metrics.avg)

        is_best = 'Best ' if self.best_epoch else ''
        out_res = (is_best + 'Epoch {} results:\n'.format(self.epoch)
                   + out_train + '\n' + out_test + '\n')
        print(out_res)
        self.log_msg(out_res)

    def vislog_epoch(self):
        x_value = self.epoch
        self.vis.line([self.train_loss.avg], [x_value], name='train_loss',
                      win='epoch_losses', update='append')
        self.vis.line([self.test_loss.avg], [x_value], name='val_loss',
                      win='epoch_losses', update='append')
        self.vis.update_window_opts(win='epoch_losses', opts=self.vis_epochloss_opts)

    # LEGACY CODE, kept for reference:
    # def adjust_lr(self):
    #     if self.opt.lr_type == 'step_lr':
    #         # Set the LR to the initial LR decayed by lr_decay_lvl
    #         # every lr_decay_period epochs
    #         lr = self.opt.lr * (self.opt.lr_decay_lvl
    #                             ** ((self.epoch + 1) // self.opt.lr_decay_period))
    #         for param_group in self.optimizer.param_groups:
    #             param_group['lr'] = lr
    #     elif self.opt.lr_type == 'cosine_lr':
    #         # Cosine LR schedule
    #         n_batches = len(self.train_loader)
    #         t_total = self.opt.num_epochs * n_batches
    #         t_cur = (self.epoch % self.opt.num_epochs) * n_batches
    #         t_cur += self.batch_idx
    #         lr_scale = 0.5 * (1 + math.cos(math.pi * t_cur / t_total))
    #         lr_scale_prev = 0.5 * (1 + math.cos(
    #             math.pi * np.clip(t_cur - 1, 0, t_total) / t_total))
    #         lr_scale_change = lr_scale / lr_scale_prev
    #         self.lr *= lr_scale_change
    #         if self.batch_idx % self.opt.log_batch_interval == 0 and self.batch_idx == 0:
    #             print(f'LR: {self.lr:.4f}')
    #         for param_group in self.optimizer.param_groups:
    #             param_group['lr'] = self.lr
    #     else:
    #         raise Exception('Unexpected lr type')

    def create_state(self):
        self.state = {
            # Params to be saved in checkpoint
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'best_test_loss': self.best_test_loss,
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
        }

    def save_state(self):
        if self.opt.log_checkpoint == 0:
            self.save_checkpoint('checkpoint.pth')
        else:
            if self.epoch % self.opt.log_checkpoint == 0:
                self.save_checkpoint('model_{}.pth'.format(self.epoch))

    def save_checkpoint(self, filename):
        # Save model to out_path/checkpoints/filename.pth
        fin_path = os.path.join(self.opt.out_path, 'checkpoints', filename)
        torch.save(self.state, fin_path)
        if self.best_epoch:
            best_fin_path = os.path.join(self.opt.out_path, 'checkpoints',
                                         'model_best.pth')
            torch.save(self.state, best_fin_path)

    def load_checkpoint(self):
        # Load the current checkpoint if it exists
        fin_path = os.path.join(self.opt.out_path, 'checkpoints', self.opt.resume)
        if os.path.isfile(fin_path):
            print("=> loading checkpoint '{}'".format(fin_path))
            checkpoint = torch.load(fin_path,
                                    map_location=lambda storage, loc: storage)
            self.epoch = checkpoint['epoch'] + 1
            self.best_test_loss = checkpoint['best_test_loss']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.classifier.load_state_dict(checkpoint['classifier_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                self.opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(self.opt.resume))

        if os.path.isfile(self.visdom_log_file):
            self.vis.replay_log(log_filename=self.visdom_log_file)
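# Resume sketch for this variant (assumption: out_path/checkpoints/checkpoint.pth
# was written by a previous run and opt.resume names that file):
#
#   opt.resume = 'checkpoint.pth'
#   model = Model(opt)
#   model.train()   # load_checkpoint restores weights, optimizer, scheduler,
#                   # and replays visdom.log before training continues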