Example #1
0
class VisdomLinePlotter(object):
    """Draw scalar series as Visdom line plots, one window per plot title.

    Window handles returned by ``Visdom.line`` are cached in ``self.plots``
    under the title joined with ``self.postfix``, so later calls for the same
    title append to the already-created window.
    """

    def __init__(self, env_name='main', logging_path=None):
        """Create the Visdom client and replay an existing log file, if any.

        Args:
            env_name: Visdom environment passed to every ``line`` call.
            logging_path: Forwarded to ``Visdom(log_to_filename=...)``. When
                it names an existing file, that file is replayed right away.
        """
        self.viz = Visdom(log_to_filename=logging_path)
        # os.path.isfile(None) raises TypeError, so the default value must be
        # filtered out before the filesystem is probed.
        if logging_path is not None and os.path.isfile(logging_path):
            self.viz.replay_log(logging_path)
        self.env = env_name
        self.postfix = ''
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y):
        """Add the point ``(x, y)`` to trace ``split_name`` of a plot.

        Args:
            var_name: Used as the y-axis label when the window is created.
            split_name: Legend entry / trace name of the series.
            title_name: Plot title; ``self.postfix`` is appended with ``_``.
            x: X coordinate of the new point.
            y: Y coordinate of the new point.
        """
        # The postfix is always joined in, so an empty postfix still yields a
        # trailing underscore in the window title.
        title_name = '_'.join([title_name, self.postfix])
        if title_name not in self.plots:
            # The first point is sent twice so the new window starts with a
            # two-element series.
            self.plots[title_name] = self.viz.line(X=np.array([x, x]),
                                                   Y=np.array([y, y]),
                                                   env=self.env,
                                                   opts=dict(
                                                       legend=[split_name],
                                                       title=title_name,
                                                       xlabel='iterations',
                                                       ylabel=var_name,
                                                   ))
        else:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[title_name],
                          name=split_name,
                          update='append')

    def set_cv(self, cv):
        """Set the title postfix to ``cv_<cv>`` for subsequent ``plot`` calls."""
        self.postfix = '_'.join(['cv', str(cv)])
Example #2
0
class VisdomLogger:
    """Thin wrapper around a Visdom client that tracks one window per title.

    ``self.windows`` maps a plot title to a dict holding the window handle
    (``'win'``), the trace name used for this run (``'name'``) and, once
    ``add_plot`` was called, the window options (``'opts'``).
    """

    def __init__(self, host='127.0.0.1', port=8097, env='main', log_path=None):
        """Connect to the server and restore windows from a previous log.

        Args:
            host: Server host name, without the ``http://`` scheme.
            port: Server port.
            env: Visdom environment name.
            log_path: Optional path of the visdom log file. When the file
                already exists it is replayed and its windows are re-registered
                so that this run continues under a new trace name.
        """
        from pathlib import Path
        from visdom import Visdom
        import json
        logger.info(f"using visdom on http://{host}:{port} env={env}")
        self.env = env
        self.viz = Visdom(server=f"http://{host}",
                          port=port,
                          env=env,
                          log_to_filename=log_path)
        self.windows = dict()
        # The default log_path is None, and callers may hand in a plain string;
        # going through Path() covers both without calling .exists() on them.
        if log_path is not None and Path(log_path).exists():
            self.viz.replay_log(log_path)
            wins = json.loads(self.viz.get_window_data(win=None, env=env))
            for v in wins.values():
                names = [int(x['name']) for x in v['content']['data']]
                # Trace names are consecutive integers; continue after the
                # highest one, or start at 1 when the window has no traces.
                name = str(max(names, default=0) + 1)
                self.windows[v['title']] = {'win': v['id'], 'name': name}

    def add_plot(self, title, **kwargs):
        """Register (or re-configure) the plot ``title``.

        Args:
            title: Plot title, also the key into ``self.windows``.
            **kwargs: Extra window options merged on top of ``{'title': title}``.
        """
        if title not in self.windows:
            self.windows[title] = {
                'win': None,
                'name': '1',
            }
        self.windows[title]['opts'] = {
            'title': title,
        }
        self.windows[title]['opts'].update(kwargs)

    def add_point(self, title, x, y):
        """Append the point ``(x, y)`` to the plot ``title``.

        The plot is registered with default options on first use, and its
        window is created by the first point sent to it.
        """
        X, Y = torch.FloatTensor([
            x,
        ]), torch.FloatTensor([
            y,
        ])
        if title not in self.windows:
            self.add_plot(title)
        if self.windows[title]['win'] is None:
            w = self.viz.line(Y=Y,
                              X=X,
                              opts=self.windows[title]['opts'],
                              name=self.windows[title]['name'])
            self.windows[title]['win'] = w
        else:
            self.viz.line(Y=Y,
                          X=X,
                          update='append',
                          win=self.windows[title]['win'],
                          name=self.windows[title]['name'])
Example #3
0
class VisdomLogger:
    """Visdom client wrapper whose windows can be shared by ranked processes.

    ``self.windows`` maps a plot title to a dict with the window ``'handle'``,
    the trace ``'name'`` used by this process (``None`` for non-trace plots)
    and the ``'opts'`` dict that is re-sent on every update.
    """

    def __init__(self,
                 host='127.0.0.1',
                 port=8097,
                 env='main',
                 log_path=None,
                 rank=None):
        """Connect to the server and, on the primary process, replay the log.

        Args:
            host: Server host name, without the ``http://`` scheme.
            port: Server port.
            env: Visdom environment name.
            log_path: Optional path of the visdom log file.
            rank: Process rank, or ``None`` when running a single process.
                Only ``None`` and ``0`` replay an existing log file.
        """
        from pathlib import Path
        from visdom import Visdom
        logger.debug(f"using visdom on http://{host}:{port} env={env}")
        self.env = env
        self.rank = rank
        self.viz = Visdom(server=f"http://{host}",
                          port=port,
                          env=env,
                          log_to_filename=log_path)
        self.windows = dict()
        is_primary = rank is None or rank == 0
        # Path() lets callers pass either a str or a path object.
        if log_path is not None and is_primary and Path(log_path).exists():
            self.viz.replay_log(log_path)

    def _get_win(self, title, type):
        """Look up an existing window by title and type on the server.

        Returns:
            ``(handle, content)`` of the matching window with the smallest
            handle, or ``(None, None)`` when nothing matches.
        """
        import json
        win_data = json.loads(self.viz.get_window_data(win=None, env=self.env))
        wins = [(w, v) for w, v in win_data.items()
                if v['title'] == title and v['type'] == type]
        if not wins:
            return None, None
        handle, value = min(wins, key=lambda x: x[0])
        return handle, value['content']

    def _get_rank0_win(self, title, type):
        """Fetch the window for ``title``, waiting for it on ranks above 0.

        Ranks greater than zero poll the server up to 10 times, 0.5 s apart.

        Raises:
            RuntimeError: If a rank above 0 never sees the window.
        """
        is_follower = self.rank is not None and self.rank > 0
        if not is_follower:
            return self._get_win(title, type)
        # wait and fetch the window handle until rank=0 client generates new window
        for _ in range(10):
            handle, content = self._get_win(title, type)
            if handle is not None:
                return handle, content
            time.sleep(0.5)
        message = "couldn't get a proper window handle from the visdom server"
        logger.error(message)
        raise RuntimeError(message)

    def _new_window(self, cmd, title, **cmd_args):
        """Create the window for ``title`` or attach to an existing one.

        Args:
            cmd: One of the bound plotting methods of ``self.viz``.
            title: Plot title, also the key into ``self.windows``.
            **cmd_args: Arguments forwarded to ``cmd``. ``opts`` is optional.
        """
        if cmd == self.viz.images:
            kind = ("image", None)
        elif cmd == self.viz.scatter or cmd == self.viz.line:
            kind = ("plot", "scatter")
        elif cmd == self.viz.heatmap:
            kind = ("plot", "heatmap")
        else:
            kind = ("plot", None)
        is_trace_plot = kind == ("plot", "scatter")

        handle, content = self._get_rank0_win(title, kind[0])

        # Guarantee an opts dict so the bookkeeping below never hits a missing
        # key, even when attaching to an existing window without options.
        opts = cmd_args.setdefault('opts', {})
        name = None
        if handle is None:
            opts.update({
                "title": title,
            })
            if is_trace_plot:
                name = f"1_{self.rank}" if self.rank is not None else "1"
                handle = cmd(name=name, **cmd_args)
            else:
                handle = cmd(**cmd_args)
        elif is_trace_plot:
            # Trace names look like "<run>" or "<run>_<rank>"; continue with
            # the next run number (1 when the window holds no trace yet).
            last_run = max(
                (int(x['name'].partition('_')[0]) for x in content['data']),
                default=0)
            name = (f"{last_run+1}_{self.rank}"
                    if self.rank is not None else f"{last_run+1}")
            cmd(win=handle, name=name, update="append", **cmd_args)
        else:
            handle = cmd(win=handle, **cmd_args)
        self.windows[title] = {
            'handle': handle,
            'name': name,
            'opts': opts,
        }

    def add_point(self, title, x, y, **kwargs):
        """Append the point ``(x, y)`` to the line plot ``title``.

        ``kwargs`` are merged into the window options.
        """
        X, Y = torch.FloatTensor([
            x,
        ]), torch.FloatTensor([
            y,
        ])
        if title not in self.windows:
            self._new_window(self.viz.line, title, X=X, Y=Y, opts=kwargs)
            return
        window = self.windows[title]
        window['opts'].update(kwargs)
        self.viz.line(win=window['handle'],
                      update='append',
                      Y=Y,
                      X=X,
                      name=window['name'],
                      opts=window['opts'])

    def plot_heatmap(self, title, tensor, **kwargs):
        """Show ``tensor`` as a heatmap in the window ``title``."""
        if title not in self.windows:
            self._new_window(self.viz.heatmap, title, X=tensor, opts=kwargs)
            return
        window = self.windows[title]
        window['opts'].update(kwargs)
        self.viz.heatmap(win=window['handle'], X=tensor, opts=window['opts'])

    def plot_images(self, title, tensor, nrow, **kwargs):
        """Show ``tensor`` as an image grid with ``nrow`` in the window ``title``."""
        if title not in self.windows:
            self._new_window(self.viz.images,
                             title,
                             tensor=tensor,
                             nrow=nrow,
                             opts=kwargs)
            return
        window = self.windows[title]
        window['opts'].update(kwargs)
        self.viz.images(win=window['handle'],
                        tensor=tensor,
                        nrow=nrow,
                        opts=window['opts'])
Example #4
0
import argparse
from visdom import Visdom

parser = argparse.ArgumentParser(description='Visdom Log Writer.')

parser.add_argument('--visdom-url',
                    type=str,
                    required=True,
                    help='visdom URL for graphs, needs http://url')
parser.add_argument('--visdom-port',
                    type=int,
                    required=True,
                    help='visdom port for graphs')
parser.add_argument('--log-file',
                    type=str,
                    required=True,
                    help='path of the visdom log file to replay')


def main(argv=None):
    """Replay a visdom log file against the server named on the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Parsing happens here rather than at import time, so importing this
    # module no longer consumes (or exits on) the host process' arguments.
    args = parser.parse_args(argv)
    visdom = Visdom(server=args.visdom_url,
                    port=args.visdom_port,
                    use_incoming_socket=False,
                    raise_exceptions=False)
    visdom.replay_log(args.log_file)


if __name__ == "__main__":
    main()
Example #5
0
class Plot(object):
    """Registry of visdom plots plus helpers to load and save environments.

    ``self.windows`` maps a plot name to a dict describing it (labels, plot
    type, ``"opts"``) and, once drawn, its visdom window under ``"plot"``.
    """

    def __init__(self, title="", env_name="", config=None, port=8080):
        """Open a visdom client.

        Args:
            title: Stored title; also used as environment name when
                ``env_name`` is empty.
            env_name: Visdom environment name.
            config: Optional mapping; ``save_env`` reads ``"results_dir"``
                from it when no file path is given.
            port: Visdom server port.
        """
        self.env_name = env_name if env_name else title
        self.viz = Visdom(port=port, env=self.env_name)
        # self.viz.close()
        self.windows = {}
        self.title = title
        self.config = config

    def register_plot(self, name, xlabel, ylabel, plot_type="line", ymax=None):
        """Declare a plot so that ``update_plot`` knows how to draw it.

        Args:
            name: Plot name, used as window title and registry key.
            xlabel: X-axis label.
            ylabel: Y-axis label.
            plot_type: ``"scatter"`` selects ``viz.scatter``; anything else
                is drawn with ``viz.line``.
            ymax: When given, the y-axis range is fixed to ``[0, ymax]``.
        """
        self.windows[name] = {"xlabel": xlabel, "ylabel": ylabel, "title": name, "plot_type": plot_type}
        self.windows[name]["opts"] = dict(title=name, markersize=5, xlabel=xlabel, ylabel=ylabel)

        if ymax is not None:
            self.windows[name]["opts"]["layoutopts"] = dict(plotly=dict(yaxis=dict(range=[0, ymax])))

    def update_plot(self, plot_name, x, y, **kwargs):
        """Draw or extend the plot ``plot_name`` with the points ``(x, y)``.

        Args:
            plot_name: Name of the plot; unknown names are registered on the
                fly as scatter plots with generic axis labels.
            x: Sequence of x coordinates; may be shorter than ``y``, in which
                case leading coordinates are interpolated.
            y: Sequence of y values.
            **kwargs: Forwarded to the plotting call that creates the window.
        """
        # Create plot if not registered. Going through register_plot gives the
        # fallback the same "opts" entry every registered plot has, and files
        # it under its own name so later calls append instead of re-creating.
        if plot_name not in self.windows:
            warnings.warn("Plot not found, creating new plot")
            self.register_plot(plot_name, "X", "Y", plot_type="scatter")
        plot_d = self.windows[plot_name]

        plotter = self.viz.scatter if plot_d["plot_type"] == "scatter" else self.viz.line

        # Copy so the in-place tweak below never writes into the caller's array.
        x = np.array(x)
        if len(x) < len(y):
            warnings.warn("X coords not found, interpolating")
            if x[0] == 0 and len(x) > 1:
                x[0] = x[1] - .001
            additional_x = np.linspace(0, x[0], len(y) - len(x))

            x = np.r_[additional_x, x]
        # NOTE(review): the scatter payload stacks x and y as rows; confirm
        # this is the layout the scatter call is expected to receive.
        data = {"X": x, "Y": np.asarray(y)} if plot_d["plot_type"] == "line" else {"X": np.asarray([x, y])}

        ## Update plot
        if "plot" in plot_d:
            plotter(
                **data,
                win=plot_d["plot"],
                update="append"
            )
        else:  # Create new plot
            win = plotter(
                **data,
                opts=plot_d["opts"], **kwargs
            )
            plot_d["plot"] = win
            self.windows[plot_name] = plot_d

    # LOADING
    def load_log(self, path):
        """Replay the visdom log file at ``path``."""
        self.viz.replay_log(path)

    def load_all_env(self, root, keyword="visdom"):
        """Replay every matching ``.json`` log found below ``root``.

        A file is replayed when its full path ends in ``.json`` and contains
        ``keyword``, its name is not ``losses.json`` and the path does not
        contain ``BSF_``.
        """
        for d, _, fs in os.walk(root):
            for f in fs:
                full_env = os.path.join(d, f)
                # Don't load "BSF" graphs, just complete graphs
                if full_env.endswith(".json") and keyword in full_env and f != "losses.json" and "BSF_" not in full_env:
                    print("Loading {}".format(full_env))
                    self.viz.replay_log(full_env)  # viz.load load the environment to viz

    def save_env(self, file_path=None, current_env=None, new_env=None):
        """Dump the windows of an environment as a replayable log file.

        Args:
            file_path: Output file; defaults to ``visdom.json`` inside
                ``self.config["results_dir"]``.
            current_env: Environment to read; defaults to ``self.env_name``.
            new_env: Environment name written into the log; defaults to
                ``current_env``.
        """
        if file_path is None:
            file_path = os.path.join(self.config["results_dir"], "visdom.json")
        if current_env is None:
            current_env = self.env_name

        new_env = current_env if new_env is None else new_env
        # Query the requested environment explicitly so ``current_env`` is
        # honoured even when it differs from the client's default.
        data = json.loads(self.viz.get_window_data(env=current_env))
        if len(data) == 0:
            print("NOTHING HAS BEEN SAVED: NOTHING IN THIS VISDOM ENV - DOES IT EXIST ?")
            return

        # The context manager closes the file even if serialisation fails.
        with open(file_path, 'w') as file:
            for datapoint in data.values():
                output = {
                    'win': datapoint['id'],
                    'eid': new_env,
                    'opts': {}
                }

                if datapoint['type'] != "plot":
                    output['data'] = [{'content': datapoint['content'], 'type': datapoint['type']}]
                    if datapoint.get('height') is not None:
                        output['opts']['height'] = datapoint['height']
                    if datapoint.get('width') is not None:
                        output['opts']['width'] = datapoint['width']
                else:
                    output['data'] = datapoint['content']["data"]
                    output['layout'] = datapoint['content']["layout"]

                to_write = json.dumps(["events", output])
                file.write(to_write + '\n')
from visdom import Visdom 
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--log_file', type=str, default='', help='log file')

opt = parser.parse_args()

viz = Visdom(port=8097)
viz.replay_log(opt.log_file)
Example #7
0
class Model:
    def __init__(self, opt):
        """Assemble everything a training run needs from the options object.

        Builds the backbone/classifier pair, loss, optimizer, schedulers,
        data loaders, running meters and the visdom client, each configured
        from attributes of ``opt``.
        """
        self.opt = opt
        device_name = "cuda" if opt.ngpu else "cpu"
        self.device = torch.device(device_name)

        backbone, head = models.get_model(opt.net_type, opt.classifier_type,
                                          opt.pretrained, int(opt.nclasses))
        self.model = backbone.to(self.device)
        self.classifier = head.to(self.device)

        # Only the backbone is wrapped when more than one GPU is requested.
        if opt.ngpu > 1:
            self.model = nn.DataParallel(self.model)

        self.loss = models.init_loss(opt.loss_type).to(self.device)

        self.optimizer = utils.get_optimizer(self.model, self.opt)
        self.lr_scheduler = utils.get_lr_scheduler(self.opt, self.optimizer)
        self.alpha_scheduler = utils.get_margin_alpha_scheduler(self.opt)

        self.train_loader = datasets.generate_loader(opt, 'train')
        self.test_loader = datasets.generate_loader(opt, 'val')

        # Run bookkeeping.
        self.epoch = 0
        self.best_epoch = False
        self.training = False
        self.state = {}

        # Running statistics.
        for meter_name in ('train_loss', 'test_loss', 'batch_time'):
            setattr(self, meter_name, utils.AverageMeter())
        self.test_metrics = utils.ROCMeter()
        self.best_test_loss = utils.AverageMeter()
        # Seed the best loss with +inf so the first evaluation always improves it.
        self.best_test_loss.update(np.array([np.inf]))

        # Visdom client; everything it draws is also written to the log file.
        log_dir = os.path.join(self.opt.out_path, 'log_files')
        self.visdom_log_file = os.path.join(log_dir, 'visdom.log')
        env_name = opt.exp_name + '_' + str(opt.fold)
        self.vis = Visdom(port=opt.visdom_port,
                          log_to_filename=self.visdom_log_file,
                          env=env_name)

        self.vis_loss_opts = dict(xlabel='epoch',
                                  ylabel='loss',
                                  title='losses',
                                  legend=['train_loss', 'val_loss'])

        self.vis_tpr_opts = dict(
            xlabel='epoch',
            ylabel='tpr',
            title='val_tpr',
            legend=['tpr@fpr10-2', 'tpr@fpr10-3', 'tpr@fpr10-4'])

        self.vis_epochloss_opts = dict(xlabel='epoch',
                                       ylabel='loss',
                                       title='epoch_losses',
                                       legend=['train_loss', 'val_loss'])

    def train(self):

        # Init Log file
        if self.opt.resume:
            self.log_msg('resuming...\n')
            # Continue training from checkpoint
            self.load_checkpoint()
        else:
            self.log_msg()

        for epoch in range(self.epoch, self.opt.num_epochs):
            self.epoch = epoch

            #freezing model
            if self.opt.freeze_epoch:
                if epoch < self.opt.freeze_epoch:
                    if self.opt.ngpu > 1:
                        for param in self.model.module.parameters():
                            param.requires_grad = False
                    else:
                        for param in self.model.parameters():
                            param.requires_grad = False
                elif epoch == self.opt.freeze_epoch:
                    if self.opt.ngpu > 1:
                        for param in self.model.module.parameters():
                            param.requires_grad = True
                    else:
                        for param in self.model.parameters():
                            param.requires_grad = True

            self.lr_scheduler.step()
            self.train_epoch()
            self.test_epoch()
            self.log_epoch()
            self.vislog_epoch()
            self.create_state()
            self.save_state()

    def train_epoch(self):
        """
        Runs one optimisation pass over the training loader.

        Every batch is moved to ``self.device``, pushed through the backbone
        and the classifier, and used for one optimizer step; the running loss
        and batch-time meters are updated and the batch is logged.
        """
        self.training = True
        self.model.train()
        self.classifier.train()
        torch.set_grad_enabled(self.training)

        self.train_loss.reset()
        self.batch_time.reset()
        self.batch_idx = 0
        time_stamp = time.time()

        for batch_idx, batch in enumerate(self.train_loader):
            self.batch_idx = batch_idx
            rgb_data, depth_data, ir_data, target = (
                tensor.to(self.device) for tensor in batch)

            self.optimizer.zero_grad()

            features = self.model(rgb_data, depth_data, ir_data)
            # A plain linear head only takes the features; any other head also
            # receives the target and, when a scheduler exists, the current alpha.
            if isinstance(self.classifier, nn.Linear):
                output = self.classifier(features)
            elif self.alpha_scheduler:
                alpha = self.alpha_scheduler.get_alpha(self.epoch)
                output = self.classifier(features, target, alpha=alpha)
            else:
                output = self.classifier(features, target)

            if self.opt.loss_type == 'bce':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)

            loss_tensor.backward()
            self.optimizer.step()

            self.train_loss.update(loss_tensor.item())
            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()

            self.log_batch(batch_idx)
            self.vislog_batch(batch_idx)

    def test_epoch(self):
        """
        Evaluates the model on the validation loader.

        Updates the test loss, the test metrics and the batch-time meter, and
        sets ``self.best_epoch`` when the average loss beats the best so far.
        """
        self.training = False
        torch.set_grad_enabled(self.training)
        self.model.eval()
        self.classifier.eval()

        for meter in (self.batch_time, self.test_loss, self.test_metrics):
            meter.reset()
        time_stamp = time.time()

        for batch_idx, batch in enumerate(self.test_loader):
            rgb_data, depth_data, ir_data, target = (
                tensor.to(self.device) for tensor in batch)

            features = self.model(rgb_data, depth_data, ir_data)
            output = self.classifier(features)

            loss_type = self.opt.loss_type
            if loss_type == 'bce':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)
            self.test_loss.update(loss_tensor.item())

            # Turn raw outputs into scores before they reach the metrics meter.
            if loss_type in ('cce', 'focal_loss'):
                output = torch.nn.functional.softmax(output, dim=1)
            elif loss_type == 'bce':
                output = torch.sigmoid(output)

            self.test_metrics.update(target.cpu().numpy(),
                                     output.cpu().numpy())

            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()

            self.log_batch(batch_idx)
            #self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx == 10):
                print('Debugging done!')
                break

        improved = self.test_loss.avg < self.best_test_loss.val
        self.best_epoch = improved
        if self.best_epoch:
            # best_test_loss.val only stores the best loss so far;
            # n=0 is passed because n is not used in the calculation.
            self.best_test_loss.update(self.test_loss.avg, n=0)

    def calculate_metrics(self, output, target):
        """
        Calculates test metrix for given batch and its input
        """
        t = target
        o = output

        if self.opt.loss_type == 'bce':
            accuracy = (t.byte() == (o > 0.5)).float().mean(0).cpu().numpy()
            batch_result.append(binary_accuracy)

        elif self.opt.loss_type == 'cce':
            top1_accuracy = (torch.argmax(o, 1) == t).float().mean().item()
            batch_result.append(top1_accuracy)
        else:
            raise Exception('This loss function is not implemented yet')

        return batch_result

    def log_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            cur_len = len(self.train_loader) if self.training else len(
                self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss

            output_string = 'Train ' if self.training else 'Test '
            output_string += 'Epoch {}[{:.2f}%]: [{:.2f}({:.3f}) s]\t'.format(
                self.epoch, 100. * batch_idx / cur_len, self.batch_time.val,
                self.batch_time.avg)

            loss_i_string = 'Loss: {:.5f}({:.5f})\t'.format(
                cur_loss.val, cur_loss.avg)
            output_string += loss_i_string

            if not self.training:
                output_string += '\n'

                metrics_i_string = 'Accuracy: {:.5f}\t'.format(
                    self.test_metrics.get_accuracy())
                output_string += metrics_i_string

            print(output_string)

    def vislog_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            loader_len = len(self.train_loader) if self.training else len(
                self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss
            loss_type = 'train_loss' if self.training else 'val_loss'

            x_value = self.epoch + batch_idx / loader_len
            y_value = cur_loss.val
            self.vis.line([y_value], [x_value],
                          name=loss_type,
                          win='losses',
                          update='append')
            self.vis.update_window_opts(win='losses', opts=self.vis_loss_opts)

    def log_msg(self, msg=''):
        mode = 'a' if msg else 'w'
        f = open(os.path.join(self.opt.out_path, 'log_files', 'train_log.txt'),
                 mode)
        f.write(msg)
        f.close()

    def log_epoch(self):
        """ Epoch results log string"""
        out_train = 'Train: '
        out_test = 'Test:  '
        loss_i_string = 'Loss: {:.5f}\t'.format(self.train_loss.avg)
        out_train += loss_i_string
        loss_i_string = 'Loss: {:.5f}\t'.format(self.test_loss.avg)
        out_test += loss_i_string

        out_test += '\nTest:  '
        metrics_i_string = 'TPR@FPR=10-2: {:.4f}\t'.format(
            self.test_metrics.get_tpr(0.01))
        metrics_i_string += 'TPR@FPR=10-3: {:.4f}\t'.format(
            self.test_metrics.get_tpr(0.001))
        metrics_i_string += 'TPR@FPR=10-4: {:.4f}\t'.format(
            self.test_metrics.get_tpr(0.0001))
        out_test += metrics_i_string

        is_best = 'Best ' if self.best_epoch else ''
        out_res = is_best + 'Epoch {} results:\n'.format(
            self.epoch) + out_train + '\n' + out_test + '\n'

        print(out_res)
        self.log_msg(out_res)

    def vislog_epoch(self):
        x_value = self.epoch
        self.vis.line([self.train_loss.avg], [x_value],
                      name='train_loss',
                      win='epoch_losses',
                      update='append')
        self.vis.line([self.test_loss.avg], [x_value],
                      name='val_loss',
                      win='epoch_losses',
                      update='append')
        self.vis.update_window_opts(win='epoch_losses',
                                    opts=self.vis_epochloss_opts)

        self.vis.line([self.test_metrics.get_tpr(0.01)], [x_value],
                      name='tpr@fpr10-2',
                      win='val_tpr',
                      update='append')
        self.vis.line([self.test_metrics.get_tpr(0.001)], [x_value],
                      name='tpr@fpr10-3',
                      win='val_tpr',
                      update='append')
        self.vis.line([self.test_metrics.get_tpr(0.0001)], [x_value],
                      name='tpr@fpr10-4',
                      win='val_tpr',
                      update='append')
        self.vis.update_window_opts(win='val_tpr', opts=self.vis_tpr_opts)

    def create_state(self):
        self.state = {       # Params to be saved in checkpoint
                'epoch' : self.epoch,
                'model_state_dict' : self.model.state_dict(),
                'classifier_state_dict': self.classifier.state_dict(),
                'best_test_loss' : self.best_test_loss,
                'optimizer': self.optimizer.state_dict(),
                'lr_scheduler': self.lr_scheduler.state_dict(),
            }

    def save_state(self):
        if self.opt.log_checkpoint == 0:
            self.save_checkpoint('checkpoint.pth')
        else:
            if (self.epoch % self.opt.log_checkpoint == 0):
                self.save_checkpoint('model_{}.pth'.format(self.epoch))

    def save_checkpoint(
            self,
            filename):  # Save model to task_name/checkpoints/filename.pth
        fin_path = os.path.join(self.opt.out_path, 'checkpoints', filename)
        torch.save(self.state, fin_path)
        if self.best_epoch:
            best_fin_path = os.path.join(self.opt.out_path, 'checkpoints',
                                         'model_best.pth')
            torch.save(self.state, best_fin_path)

    def load_checkpoint(self):  # Load current checkpoint if exists
        fin_path = os.path.join(self.opt.out_path, 'checkpoints',
                                self.opt.resume)
        if os.path.isfile(fin_path):
            print("=> loading checkpoint '{}'".format(fin_path))
            checkpoint = torch.load(fin_path,
                                    map_location=lambda storage, loc: storage)
            self.epoch = checkpoint['epoch'] + 1
            self.best_test_loss = checkpoint['best_test_loss']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.classifier.load_state_dict(
                checkpoint['classifier_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            #self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                self.opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(self.opt.resume))

        if os.path.isfile(self.visdom_log_file):
            self.vis.replay_log(log_filename=self.visdom_log_file)
class Model:
    def __init__(self, opt):
        """Build the network, loss, optimizer, loaders, meters and Visdom client.

        Args:
            opt: options object.  Attributes read directly here are ``ngpu``,
                ``net_type``, ``loss_type``, ``pretrained``, ``nclasses``,
                ``out_path``, ``visdom_port``, ``exp_name`` and ``fold``; the
                project helpers called below receive ``opt`` as well.
        """
        self.opt = opt
        device_name = "cuda" if opt.ngpu else "cpu"
        self.device = torch.device(device_name)

        # Backbone and classification head come from the project factory.
        backbone, head = models.get_model(opt.net_type,
                                          opt.loss_type,
                                          opt.pretrained,
                                          int(opt.nclasses))
        self.model = backbone.to(self.device)
        self.classifier = head.to(self.device)
        if opt.ngpu > 1:
            # Only the backbone is wrapped for multi-GPU execution.
            self.model = nn.DataParallel(self.model)

        self.loss = models.init_loss(opt.loss_type)
        self.loss = self.loss.to(self.device)

        # NOTE(review): only self.model is handed to get_optimizer; whether
        # the classifier's parameters get optimised depends on that helper.
        self.optimizer = utils.get_optimizer(self.model, self.opt)
        self.lr_scheduler = utils.get_lr_scheduler(self.opt, self.optimizer)

        self.train_loader = datasets.generate_loader(opt, 'train')
        self.test_loader = datasets.generate_loader(opt, 'val')

        # Bookkeeping for the training loop.
        self.epoch = 0
        self.best_epoch = False
        self.training = False
        self.state = {}

        self.train_loss = utils.AverageMeter()
        self.test_loss = utils.AverageMeter()
        self.batch_time = utils.AverageMeter()
        uses_average_metric = self.opt.loss_type in ['cce', 'bce', 'mse',
                                                     'arc_margin']
        self.test_metrics = (utils.AverageMeter() if uses_average_metric
                             else utils.ROCMeter())

        # Start the best loss at +inf so the first evaluated epoch compares
        # as lower.
        self.best_test_loss = utils.AverageMeter()
        self.best_test_loss.update(np.array([np.inf]))

        self.visdom_log_file = os.path.join(self.opt.out_path, 'log_files',
                                            'visdom.log')
        self.vis = Visdom(port=opt.visdom_port,
                          log_to_filename=self.visdom_log_file,
                          env=opt.exp_name + '_' + str(opt.fold))

        # Window options for the per-batch and per-epoch loss plots.
        self.vis_loss_opts = dict(xlabel='epoch',
                                  ylabel='loss',
                                  title='losses',
                                  legend=['train_loss', 'val_loss'])
        self.vis_epochloss_opts = dict(xlabel='epoch',
                                       ylabel='loss',
                                       title='epoch_losses',
                                       legend=['train_loss', 'val_loss'])

    def train(self):
        
        # Init Log file
        if self.opt.resume:
            self.log_msg('resuming...\n')
            # Continue training from checkpoint
            self.load_checkpoint()
        else:
             self.log_msg()


        for epoch in range(self.epoch, self.opt.num_epochs):
            self.epoch = epoch
            
            '''
            if epoch < 0:
                for param in self.model.module.body.parameters():
                    param.requires_grad=False
            elif epoch == 0:
                for param in self.model.module.body.parameters():
                    param.requires_grad=True
            '''

            self.lr_scheduler.step()
            self.train_epoch()
            self.test_epoch()
            self.log_epoch()
            self.vislog_epoch()
            self.create_state()
            self.save_state()  
    
    def train_epoch(self):
        """
        Trains model for 1 epoch
        """
        self.model.train()
        self.classifier.train()
        self.training = True
        torch.set_grad_enabled(self.training)
        self.train_loss.reset()
        self.batch_time.reset()
        time_stamp = time.time()
        self.batch_idx = 0
        for batch_idx, (data, target) in enumerate(self.train_loader):
            
            self.batch_idx = batch_idx
            data = data.to(self.device)
            target = target.to(self.device)

            self.optimizer.zero_grad()
            
            output = self.model(data)

            if isinstance(self.classifier, nn.Linear):
                output = self.classifier(output)
            else:
                output = self.classifier(output, target)

            if self.opt.loss_type == 'bce' or self.opt.loss_type == 'mse':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)

            loss_tensor.backward()   

            self.optimizer.step()

            self.train_loss.update(loss_tensor.item())
            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            
            self.log_batch(batch_idx)
            self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx==10):
                print('Debugging done!')
                break;
            
    def test_epoch(self):
        """
        Calculates loss and metrics for test set
        """
        self.training = False
        torch.set_grad_enabled(self.training)
        self.model.eval()
        self.classifier.eval()
        
        self.batch_time.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        time_stamp = time.time()
        
        for batch_idx, (data, target) in enumerate(self.test_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            
            output = self.classifier(output)
            if self.opt.loss_type == 'bce' or self.opt.loss_type == 'mse':
                target = target.float()
                loss_tensor = self.loss(output.squeeze(), target)
            else:
                loss_tensor = self.loss(output, target)
            self.test_loss.update(loss_tensor.item())

            if self.opt.loss_type == 'cce':
                output = torch.nn.functional.softmax(output, dim=1)
            elif self.opt.loss_type.startswith('arc_margin'):
                output = torch.nn.functional.softmax(output, dim=1)
            elif self.opt.loss_type == 'bce':
                output = torch.sigmoid(output)

            metrics = self.calculate_metrics(output, target)
            self.test_metrics.update(metrics)

            self.batch_time.update(time.time() - time_stamp)
            time_stamp = time.time()
            
            self.log_batch(batch_idx)
            #self.vislog_batch(batch_idx)
            if self.opt.debug and (batch_idx==10):
                print('Debugging done!')
                break;

        self.best_epoch = self.test_loss.avg < self.best_test_loss.val
        if self.best_epoch:
            # self.best_test_loss.val is container for best loss, 
            # n is not used in the calculation
            self.best_test_loss.update(self.test_loss.avg, n=0)
     
    def calculate_metrics(self, output, target):   
        """
        Calculates test metrix for given batch and its input
        """
        batch_result = None
        
        t = target
        o = output
            
        if self.opt.loss_type == 'bce':
            binary_accuracy = (t.byte()==(o>0.5)).float().mean(0).cpu().numpy()  
            batch_result = binary_accuracy
        elif self.opt.loss_type =='mse':
            mean_average_error = torch.abs(t-o.squeeze()).mean(0).cpu().numpy()
            batch_result = mean_average_error
        elif self.opt.loss_type == 'cce' or self.opt.loss_type == 'arc_margin':
            top1_accuracy = (torch.argmax(o, 1)==t).float().mean().item()
            batch_result = top1_accuracy
        else:
            raise Exception('This loss function is not implemented yet')
                
        return batch_result  

    
    def log_batch(self, batch_idx):
        if batch_idx % self.opt.log_batch_interval == 0:
            cur_len = len(self.train_loader) if self.training else len(self.test_loader)
            cur_loss = self.train_loss if self.training else self.test_loss
            
            output_string = 'Train ' if self.training else 'Test '
            output_string +='Epoch {}[{:.2f}%]: [{:.2f}({:.3f}) s]\t'.format(self.epoch,
                                                                          100.* batch_idx/cur_len, self.batch_time.val,self.batch_time.avg)
            
            loss_i_string = 'Loss: {:.5f}({:.5f})\t'.format(cur_loss.val, cur_loss.avg)
            output_string += loss_i_string
            
            print(output_string)
    
    def vislog_batch(self, batch_idx):
        loader_len = len(self.train_loader) if self.training else len(self.test_loader)
        cur_loss = self.train_loss if self.training else self.test_loss
        loss_type = 'train_loss' if self.training else 'val_loss'
        
        x_value = self.epoch + batch_idx / loader_len
        y_value = cur_loss.val
        self.vis.line([y_value], [x_value], 
                        name=loss_type, 
                        win='losses', 
                        update='append')
        self.vis.update_window_opts(win='losses', opts=self.vis_loss_opts)
    
    def log_msg(self, msg=''):
        mode = 'a' if msg else 'w'
        f = open(os.path.join(self.opt.out_path, 'log_files', 'train_log.txt'), mode)
        f.write(msg)
        f.close()
             
    def log_epoch(self):
        """ Epoch results log string"""
        out_train = 'Train: '
        out_test = 'Test:  '
        loss_i_string = 'Loss: {:.5f}\t'.format(self.train_loss.avg)
        out_train += loss_i_string
        loss_i_string = 'Loss: {:.5f}\t'.format(self.test_loss.avg)
        out_test += loss_i_string
            
        out_test+='\nTest:  '
        out_test+= '{0}\t{1:.4f}\t'.format(self.opt.loss_type, self.test_metrics.avg)
            
        is_best = 'Best ' if self.best_epoch else ''
        out_res = is_best+'Epoch {} results:\n'.format(self.epoch)+out_train+'\n'+out_test+'\n'
        
        print(out_res)
        self.log_msg(out_res)
        

    def vislog_epoch(self):
        x_value = self.epoch
        self.vis.line([self.train_loss.avg], [x_value], 
                        name='train_loss', 
                        win='epoch_losses', 
                        update='append')
        self.vis.line([self.test_loss.avg], [x_value], 
                        name='val_loss', 
                        win='epoch_losses', 
                        update='append')
        self.vis.update_window_opts(win='epoch_losses', opts=self.vis_epochloss_opts)


    ''' LEGACY CODE '''
    '''
    def adjust_lr(self):
        if self.opt.lr_type == 'step_lr':
            Set the LR to the initial LR decayed by lr_decay_lvl every lr_decay_period epochs
            lr = self.opt.lr * (self.opt.lr_decay_lvl ** ((self.epoch+1) // self.opt.lr_decay_period))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        elif self.opt.lr_type == 'cosine_lr':
            Cosine LR by [email protected] and [email protected]
            n_batches = len(self.train_loader)
            t_total = self.opt.num_epochs * n_batches
            t_cur = ((self.epoch) % self.opt.num_epochs) * n_batches
            t_cur += self.batch_idx
            lr_scale = 0.5 * (1 + math.cos(math.pi * t_cur / t_total))
            lr_scale_prev = 0.5 * (1 + math.cos(
                math.pi * np.clip((t_cur - 1), 0, t_total) / t_total))
            lr_scale_change = lr_scale / lr_scale_prev
            self.lr *= lr_scale_change
            if self.batch_idx % self.opt.log_batch_interval == 0 and self.batch_idx == 0:
                print (f'LR: {self.lr:.4f}')
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        else:
            raise Exception('Unexpected lr type') 
    '''
                    
    def create_state(self):
        self.state = {       # Params to be saved in checkpoint
                'epoch' : self.epoch,
                'model_state_dict' : self.model.state_dict(),
                'classifier_state_dict': self.classifier.state_dict(),
                'best_test_loss' : self.best_test_loss,
                'optimizer': self.optimizer.state_dict(),
                'lr_scheduler': self.lr_scheduler.state_dict(),
            }
    
    def save_state(self):
        if self.opt.log_checkpoint == 0:
                self.save_checkpoint('checkpoint.pth')
        else:
            if (self.epoch % self.opt.log_checkpoint == 0):
                self.save_checkpoint('model_{}.pth'.format(self.epoch)) 
                  
    def save_checkpoint(self, filename):     # Save model to task_name/checkpoints/filename.pth
        fin_path = os.path.join(self.opt.out_path,'checkpoints', filename)
        torch.save(self.state, fin_path)
        if self.best_epoch:
            best_fin_path = os.path.join(self.opt.out_path, 'checkpoints', 'model_best.pth')
            torch.save(self.state, best_fin_path)
           

    def load_checkpoint(self):                            # Load current checkpoint if exists
        fin_path = os.path.join(self.opt.out_path,'checkpoints',self.opt.resume)
        if os.path.isfile(fin_path):
            print("=> loading checkpoint '{}'".format(fin_path))
            checkpoint = torch.load(fin_path, map_location=lambda storage, loc: storage)
            self.epoch = checkpoint['epoch'] + 1
            self.best_test_loss = checkpoint['best_test_loss']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.classifier.load_state_dict(checkpoint['classifier_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

            print("=> loaded checkpoint '{}' (epoch {})".format(self.opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(self.opt.resume))

        if os.path.isfile(self.visdom_log_file):
                self.vis.replay_log(log_filename=self.visdom_log_file)