Example #1
def run() -> None:
    # Check if Redis enabled
    if not UserConf.redis_enabled:
        raise InitialisationException('Redis is not set up. Run the setup '
                                      'script to configure Redis.')

    logger = create_logger(InternalConf.redis_log_file, 'redis',
                           InternalConf.logging_level)

    print('Deleting all Redis keys.')

    # Redis database
    try:
        RedisApi(
            logger, InternalConf.redis_database, UserConf.redis_host,
            UserConf.redis_port, password=UserConf.redis_password,
            namespace=UserConf.unique_alerter_identifier
        ).delete_all_unsafe()
    except Exception as e:
        sys.exit(e)

    # Redis test database
    try:
        RedisApi(
            logger, InternalConf.redis_test_database, UserConf.redis_host,
            UserConf.redis_port, password=UserConf.redis_password,
            namespace=UserConf.unique_alerter_identifier
        ).delete_all_unsafe()
    except Exception as e:
        sys.exit(e)

    print('Done deleting all Redis keys.')
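
Every example on this page calls a create_logger helper whose definition is not shown. The sketch below is only a plausible reconstruction inferred from the call sites above (log file path, logger name, logging level, optional rotating flag); the real projects' implementations may differ, and the nuguziii/MI project (Examples #6 and #17) uses a different create_logger that returns a tuple of logger and output directories.

import logging
from logging.handlers import RotatingFileHandler


def create_logger(log_file: str, name: str, level,
                  rotating: bool = False) -> logging.Logger:
    # Hypothetical sketch: one named logger per component, writing to its
    # own file, optionally with size-based rotation.
    logger = logging.getLogger(name)
    logger.setLevel(level)
    handler = (RotatingFileHandler(log_file, maxBytes=10_000_000,
                                   backupCount=3)
               if rotating else logging.FileHandler(log_file))
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    return logger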
Example #2
def _initialise_logger(component_display_name: str, component_module_name: str,
                       log_file_template: str) -> logging.Logger:
    # Try initialising the logger until successful. This is done in its own
    # loop so that we never attempt to use a logger whose creation failed.
    while True:
        try:
            new_logger = create_logger(
                log_file_template.format(component_display_name),
                component_module_name,
                env.LOGGING_LEVEL,
                rotating=True)
            break
        except Exception as e:
            # Use a dummy logger in this case because we cannot create the
            # component's logger.
            dummy_logger = logging.getLogger('DUMMY_LOGGER')
            log_and_print(
                get_initialisation_error_message(component_display_name, e),
                dummy_logger)
            log_and_print(get_reattempting_message(component_display_name),
                          dummy_logger)
            # sleep before trying again
            time.sleep(RE_INITIALISE_SLEEPING_PERIOD)

    return new_logger
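
log_and_print is another helper that is not defined on this page. Its call sites take a message and, in some examples, a logger, so a minimal sketch consistent with that usage (an assumption, not the projects' actual code) is:

import logging
import sys
from typing import Optional


def log_and_print(text: str, logger: Optional[logging.Logger] = None) -> None:
    # Hypothetical sketch: mirror the message to stdout and, when a logger
    # is supplied, to its log file as well.
    print(text)
    sys.stdout.flush()
    if logger is not None:
        logger.info(text)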
Example #3
def run_monitor_nodes(node: Node):
    # Monitor name based on node
    monitor_name = 'Node monitor ({})'.format(node.name)

    # Logger initialisation
    logger_monitor_node = create_logger(
        InternalConf.node_monitor_general_log_file_template.format(node.name),
        node.name,
        InternalConf.logging_level,
        rotating=True)

    # Initialise monitor
    node_monitor = NodeMonitor(monitor_name, channel_set, logger_monitor_node,
                               REDIS, node)

    # Start
    log_and_print('{} started.'.format(monitor_name))
    sys.stdout.flush()
    try:
        start_node_monitor(node_monitor,
                           InternalConf.node_monitor_period_seconds,
                           logger_monitor_node)
    except Exception as e:
        channel_set.alert_error(TerminatedDueToExceptionAlert(monitor_name, e))
    log_and_print('{} stopped.'.format(monitor_name))
Example #4
def _get_log_channel(alerts_log_file: str, channel_name: str,
                     logger_general: logging.Logger,
                     internal_conf: InternalConfig = InternalConf) \
        -> LogChannel:
    # Logger initialisation
    logger_alerts = create_logger(alerts_log_file, 'alerts',
                                  internal_conf.logging_level)
    return LogChannel(channel_name, logger_general, logger_alerts)
Example #5
def run_monitor_nodes(node: Node):
    # Monitor name based on node
    monitor_name = 'Node monitor ({})'.format(node.name)
    try:
        # Logger initialisation
        logger_monitor_node = create_logger(
            InternalConf.node_monitor_general_log_file_template.format(
                node.name),
            node.name,
            InternalConf.logging_level,
            rotating=True)

        # Get the data sources which belong to the same chain and prioritise
        # them over the node itself as data sources for indirect node monitoring
        data_sources = [
            data_source for data_source in data_source_nodes if
            node.chain == data_source.chain and data_source.name != node.name
        ]
        if node in data_source_nodes:
            data_sources.append(node)

        # Do not start if there is no data source.
        if len(data_sources) == 0:
            log_and_print(
                'Indirect monitoring will be disabled for node {} because no '
                'data source for chain {} was given in the nodes config file.'
                ''.format(node.name, node.chain))

        # Initialise monitor
        node_monitor = NodeMonitor(
            monitor_name, full_channel_set, logger_monitor_node,
            InternalConf.node_monitor_max_catch_up_blocks, REDIS, node,
            archive_alerts_disabled_by_chain[node.chain], data_sources,
            UserConf.polkadot_api_endpoint)
    except Exception as e:
        msg = '!!! Error when initialising {}: {} !!!'.format(monitor_name, e)
        log_and_print(msg)
        raise InitialisationException(msg)

    while True:
        # Start
        log_and_print('{} started.'.format(monitor_name))
        sys.stdout.flush()
        try:
            start_node_monitor(node_monitor,
                               InternalConf.node_monitor_period_seconds,
                               logger_monitor_node)
        except (UnexpectedApiCallErrorException,
                UnexpectedApiErrorWhenReadingDataException,
                InvalidStashAccountAddressException) as e:
            full_channel_set.alert_error(
                TerminatedDueToFatalExceptionAlert(monitor_name, e))
            log_and_print('{} stopped.'.format(monitor_name))
            break
        except Exception as e:
            full_channel_set.alert_error(
                TerminatedDueToExceptionAlert(monitor_name, e))
        log_and_print('{} stopped.'.format(monitor_name))
Example #6
File: main.py Project: nuguziii/MI
    def test(self):
        logger, final_output_dir, result_dir = create_logger(
            self.config['log_dir'], self.config['description'], 'test')

        logger.info(pprint.pformat(self.config))

        # TODO: set data loader
        test_dataset = self.dataset(self.width,
                                    self.height,
                                    self.depth,
                                    self.config['data_dir'] + "\\test",
                                    self.config['data_dir'] + "\\label",
                                    aug=[])
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=4,
                                                  pin_memory=True)

        logger.info("=> loading model '{}'".format(self.config['model_dir']))
        model = torch.load(self.config['model_dir'])

        # TODO: add meter (metric)
        batch_time = AverageMeter()
        data_time = AverageMeter()

        end = time.time()
        model.eval()
        for idx, (image, label, shape_label) in enumerate(test_loader):
            data_time.update(time.time() - end)

            # TODO: test model (inference only, so disable gradient tracking)
            with torch.no_grad():
                output = model(image)

            # TODO: metric
            # ...

            # TODO: save result, visualize (optional)
            # ...

            batch_time.update(time.time() - end)
            end = time.time()

            msg = '[{0}/{1}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)'.format(
                idx + 1, len(test_loader), batch_time=batch_time,
                speed=image.size(0) / batch_time.val,
                data_time=data_time)
            logger.info(msg)

        msg = '[total]\t' \
              'Data {data_time.avg:.3f}s'.format(
            data_time=data_time)
        logger.info(msg)
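
AverageMeter is also undefined here; the .update()/.val/.avg usage above matches the running-average meter that is common in PyTorch example code. A minimal version consistent with that usage (an assumption):

class AverageMeter:
    # Tracks the latest value and the running average of a metric.

    def __init__(self) -> None:
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        # Record a new observation, optionally weighted by batch size n.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0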
Example #7
def run_monitor_network(network_nodes_tuple: Tuple[str, List[Node]]):
    # Get network and nodes
    network = network_nodes_tuple[0]
    nodes = network_nodes_tuple[1]

    # Monitor name based on network
    monitor_name = 'Network monitor ({})'.format(network)

    # Initialisation
    try:
        # Logger initialisation
        logger_monitor_network = create_logger(
            InternalConf.network_monitor_general_log_file_template.format(
                network),
            network,
            InternalConf.logging_level,
            rotating=True)

        # Organize as validators and full nodes
        validators = [n for n in nodes if n.is_validator]
        full_nodes = [n for n in nodes if not n.is_validator]

        # Do not start if not enough nodes
        if len(validators) == 0 or len(full_nodes) == 0:
            log_and_print('!!! Could not start {}. It must have at least 1 '
                          'validator and 1 full node!!!'.format(monitor_name))
            return

        # Initialise monitor
        network_monitor = NetworkMonitor(
            monitor_name, full_channel_set, logger_monitor_network,
            InternalConf.network_monitor_max_catch_up_blocks, REDIS,
            full_nodes, validators)
    except Exception as e:
        msg = '!!! Error when initialising {}: {} !!!'.format(monitor_name, e)
        log_and_print(msg)
        raise InitialisationException(msg)

    while True:
        # Start
        log_and_print('{} started with {} validator(s) and {} full node(s).'
                      ''.format(monitor_name, len(validators),
                                len(full_nodes)))
        sys.stdout.flush()
        try:
            start_network_monitor(network_monitor,
                                  InternalConf.network_monitor_period_seconds,
                                  logger_monitor_network)
        except Exception as e:
            full_channel_set.alert_error(
                TerminatedDueToExceptionAlert(monitor_name, e))
        log_and_print('{} stopped.'.format(monitor_name))
Example #8
def run_monitor_blockchain(blockchain_nodes_tuple: Tuple[str, List[Node]]):
    # Get blockchain and nodes
    blockchain_name = blockchain_nodes_tuple[0]
    data_sources = blockchain_nodes_tuple[1]

    # Monitor name based on blockchain
    monitor_name = 'Blockchain monitor ({})'.format(blockchain_name)

    # Initialisation
    try:
        # Logger initialisation
        logger_monitor_blockchain = create_logger(
            InternalConf.blockchain_monitor_general_log_file_template.format(
                blockchain_name),
            blockchain_name,
            InternalConf.logging_level,
            rotating=True)

        # Create blockchain object
        blockchain = Blockchain(blockchain_name, REDIS)

        # Initialise monitor
        blockchain_monitor = BlockchainMonitor(monitor_name, blockchain,
                                               full_channel_set,
                                               logger_monitor_blockchain,
                                               REDIS, data_sources,
                                               UserConf.polkadot_api_endpoint)
    except Exception as e:
        msg = '!!! Error when initialising {}: {} !!!'.format(monitor_name, e)
        log_and_print(msg)
        raise InitialisationException(msg)

    while True:
        # Start
        log_and_print('{} started.'.format(monitor_name))
        sys.stdout.flush()
        try:
            start_blockchain_monitor(
                blockchain_monitor,
                InternalConf.blockchain_monitor_period_seconds,
                logger_monitor_blockchain)
        except (UnexpectedApiCallErrorException,
                UnexpectedApiErrorWhenReadingDataException) as e:
            full_channel_set.alert_error(
                TerminatedDueToFatalExceptionAlert(monitor_name, e))
            log_and_print('{} stopped.'.format(monitor_name))
            break
        except Exception as e:
            full_channel_set.alert_error(
                TerminatedDueToExceptionAlert(monitor_name, e))
        log_and_print('{} stopped.'.format(monitor_name))
Example #9
def _initialise_alerts_logger() -> logging.Logger:
    # Try initialising the logger until successful. This is done in its own
    # loop so that we never attempt to use a logger whose creation failed.
    while True:
        try:
            alerts_logger = create_logger(env.ALERTS_LOG_FILE, 'Alerts',
                                          env.LOGGING_LEVEL, True)
            break
        except Exception as e:
            msg = get_initialisation_error_message('Alerts Log File', e)
            # Use a dummy logger in this case because we cannot create the
            # logger.
            log_and_print(msg, logging.getLogger('DUMMY_LOGGER'))
            # sleep before trying again
            time.sleep(RE_INITIALISE_SLEEPING_PERIOD)

    return alerts_logger
Example #10
def _initialise_alert_router() -> Tuple[AlertRouter, logging.Logger]:
    display_name = ALERT_ROUTER_NAME

    # Try initialising the logger until successful. This is done in its own
    # loop so that we never attempt to use a logger whose creation failed.
    while True:
        try:
            alert_router_logger = create_logger(env.ALERT_ROUTER_LOG_FILE,
                                                AlertRouter.__name__,
                                                env.LOGGING_LEVEL,
                                                rotating=True)
            break
        except Exception as e:
            # Use a dummy logger in this case because we cannot create the
            # alert router's logger.
            dummy_logger = logging.getLogger('DUMMY_LOGGER')
            log_and_print(get_initialisation_error_message(display_name, e),
                          dummy_logger)
            log_and_print(get_reattempting_message(display_name), dummy_logger)
            # sleep before trying again
            time.sleep(RE_INITIALISE_SLEEPING_PERIOD)

    rabbit_ip = env.RABBIT_IP
    redis_ip = env.REDIS_IP
    redis_db = env.REDIS_DB
    redis_port = env.REDIS_PORT
    unique_alerter_identifier = env.UNIQUE_ALERTER_IDENTIFIER

    while True:
        try:
            alert_router = AlertRouter(display_name, alert_router_logger,
                                       rabbit_ip, redis_ip, redis_db,
                                       redis_port, unique_alerter_identifier,
                                       env.ENABLE_CONSOLE_ALERTS,
                                       env.ENABLE_LOG_ALERTS)
            return alert_router, alert_router_logger
        except Exception as e:
            log_and_print(get_initialisation_error_message(display_name, e),
                          alert_router_logger)
            log_and_print(get_reattempting_message(display_name),
                          alert_router_logger)
            # sleep before trying again
            time.sleep(RE_INITIALISE_SLEEPING_PERIOD)
Example #11
def _initialise_store_logger(
        store_display_name: str, store_module_name: str) -> logging.Logger:
    # Try initialising the logger until successful. This is done in its own
    # loop so that we never attempt to use a logger whose creation failed.
    while True:
        try:
            store_logger = create_logger(
                env.DATA_STORE_LOG_FILE_TEMPLATE.format(store_display_name),
                store_module_name, env.LOGGING_LEVEL, rotating=True)
            break
        except Exception as e:
            msg = get_initialisation_error_message(store_display_name, e)
            # Use a dummy logger in this case because we cannot create the
            # store's logger.
            log_and_print(msg, logging.getLogger('DUMMY_LOGGER'))
            # sleep before trying again
            time.sleep(RE_INITIALISE_SLEEPING_PERIOD)

    return store_logger
Example #12
def run_monitor_systems(node: Node):
    # Monitor name based on node
    monitor_name = 'System monitor ({})'.format(node.name)

    try:
        # Logger initialisation
        logger_monitor_system = create_logger(
            InternalConf.system_monitor_general_log_file_template.format(
                node.name),
            node.name,
            InternalConf.logging_level,
            rotating=True)

        # Create a system from the node
        system = System(node.name, REDIS, node, InternalConf)

        # Initialize the SystemMonitor
        system_monitor = SystemMonitor(monitor_name, system, full_channel_set,
                                       logger_monitor_system, REDIS,
                                       node.prometheus_endpoint, InternalConf)
    except Exception as e:
        msg = '!!! Error when initialising {}: {} !!!'.format(monitor_name, e)
        log_and_print(msg)
        raise InitialisationException(msg)

    while True:
        # Start
        log_and_print('{} started.'.format(monitor_name))
        try:
            start_system_monitor(system_monitor,
                                 InternalConf.system_monitor_period_seconds,
                                 logger_monitor_system)
        except Exception as e:
            # Any exception is treated as fatal here: alert and stop the
            # monitor.
            full_channel_set.alert_error(
                TerminatedDueToFatalExceptionAlert(monitor_name, e))
            log_and_print('{} stopped.'.format(monitor_name))
            break
        log_and_print('{} stopped.'.format(monitor_name))
Example #13
def run() -> None:
    # Check if Mongo enabled
    if not UserConf.mongo_enabled:
        raise InitialisationException('Mongo is not set up. Run the setup '
                                      'script to configure Mongo.')

    logger = create_logger(InternalConf.mongo_log_file, 'mongo',
                           InternalConf.logging_level)

    db_name = UserConf.mongo_db_name
    print('Deleting "{}" database from MongoDB.'.format(db_name))

    # Attempt to delete database
    try:
        MongoApi(logger, UserConf.mongo_db_name, UserConf.mongo_host,
                 UserConf.mongo_port, UserConf.mongo_user,
                 UserConf.mongo_pass).drop_db()
    except Exception as e:
        sys.exit(e)

    print('Done deleting "{}" database from MongoDB.'.format(db_name))
Example #14
def run_monitor_github(repo_config: RepoConfig):
    # Monitor name based on repository
    monitor_name = 'GitHub monitor ({})'.format(repo_config.repo_name)

    # Initialisation
    try:
        # Logger initialisation
        logger_monitor_github = create_logger(
            InternalConf.github_monitor_general_log_file_template.format(
                repo_config.repo_page.replace('/', '_')),
            repo_config.repo_page,
            InternalConf.logging_level,
            rotating=True)

        # Get releases page
        releases_page = InternalConf.github_releases_template.format(
            repo_config.repo_page)

        # Initialise monitor
        github_monitor = GitHubMonitor(
            monitor_name, full_channel_set, logger_monitor_github, REDIS,
            repo_config.repo_name, releases_page,
            InternalConf.redis_github_releases_key_prefix)
    except Exception as e:
        msg = '!!! Error when initialising {}: {} !!!'.format(monitor_name, e)
        log_and_print(msg)
        raise InitialisationException(msg)

    while True:
        # Start
        log_and_print('{} started.'.format(monitor_name))
        sys.stdout.flush()
        try:
            start_github_monitor(github_monitor,
                                 InternalConf.github_monitor_period_seconds,
                                 logger_monitor_github)
        except Exception as e:
            full_channel_set.alert_error(
                TerminatedDueToExceptionAlert(monitor_name, e))
        log_and_print('{} stopped.'.format(monitor_name))
Example #15
        except Exception as e:
            periodic_alive_reminder_channel_set.alert_error(
                TerminatedDueToExceptionAlert(name, e))
        log_and_print('{} stopped.'.format(name))


if __name__ == '__main__':
    if not INTERNAL_CONFIG_FILE_FOUND:
        sys.exit('Config file {} is missing.'.format(INTERNAL_CONFIG_FILE))
    elif len(MISSING_USER_CONFIG_FILES) > 0:
        sys.exit('Config file {} is missing. Make sure that you run the setup '
                 'script (run_setup.py) before running the alerter.'
                 ''.format(MISSING_USER_CONFIG_FILES[0]))

    # Global loggers and polkadot data wrapper initialisation
    logger_redis = create_logger(InternalConf.redis_log_file, 'redis',
                                 InternalConf.logging_level)
    logger_mongo = create_logger(InternalConf.mongo_log_file, 'mongo',
                                 InternalConf.logging_level)
    logger_general = create_logger(InternalConf.general_log_file,
                                   'general',
                                   InternalConf.logging_level,
                                   rotating=True)
    logger_commands_telegram = create_logger(
        InternalConf.telegram_commands_general_log_file,
        'commands_telegram',
        InternalConf.logging_level,
        rotating=True)
    log_file_alerts = InternalConf.alerts_log_file
    polkadot_api_data_wrapper = \
        PolkadotApiWrapper(logger_general, UserConf.polkadot_api_endpoint)
Example #16
def get_full_channel_set(channel_name: str, logger_general: logging.Logger,
                         redis: Optional[RedisApi], alerts_log_file: str,
                         internal_conf: InternalConfig = InternalConf,
                         user_conf: UserConfig = UserConf) -> ChannelSet:
    # Logger initialisation
    logger_alerts = create_logger(alerts_log_file, 'alerts',
                                  internal_conf.logging_level)

    # Initialise list of channels with default channels
    channels = [
        ConsoleChannel(channel_name, logger_general),
        LogChannel(channel_name, logger_general, logger_alerts)
    ]

    # Initialise backup channel sets with default channels
    backup_channels_for_telegram = ChannelSet(channels)
    backup_channels_for_twilio = ChannelSet(channels)

    # Add telegram alerts to channel set
    if user_conf.telegram_alerts_enabled:
        telegram_bot = TelegramBotApi(user_conf.telegram_alerts_bot_token,
                                      user_conf.telegram_alerts_bot_chat_id)
        telegram_channel = TelegramChannel(channel_name, logger_general, redis,
                                           telegram_bot,
                                           backup_channels_for_telegram)
        channels.append(telegram_channel)
    else:
        telegram_channel = None

    # Add email alerts to channel set
    if user_conf.email_alerts_enabled:
        email = EmailSender(user_conf.email_smtp, user_conf.email_from)
        email_channel = EmailChannel(channel_name, logger_general,
                                     redis, email, user_conf.email_to)
        channels.append(email_channel)
    else:
        email_channel = None

    # Add twilio alerts to channel set
    if user_conf.twilio_alerts_enabled:
        twilio = TwilioApi(user_conf.twilio_account_sid,
                           user_conf.twilio_auth_token)
        twilio_channel = TwilioChannel(channel_name, logger_general, redis,
                                       twilio, user_conf.twilio_phone_number,
                                       user_conf.twilio_dial_numbers,
                                       internal_conf.twiml_instructions_url,
                                       internal_conf.redis_twilio_snooze_key,
                                       backup_channels_for_twilio)
        channels.append(twilio_channel)
    else:
        # noinspection PyUnusedLocal
        twilio_channel = None

    # Set up email channel as backup channel for telegram and twilio
    if email_channel is not None:
        backup_channels_for_telegram.add_channel(email_channel)
        backup_channels_for_twilio.add_channel(email_channel)

    # Set up telegram channel as backup channel for twilio
    if telegram_channel is not None:
        backup_channels_for_twilio.add_channel(telegram_channel)

    return ChannelSet(channels)
Example #17
File: main.py Project: nuguziii/MI
    def train(self):
        print(self.config['log_dir'])
        logger, final_output_dir, tb_log_dir = create_logger(
            self.config['log_dir'], self.config['description'], 'train')

        logger.info(pprint.pformat(self.config))

        writer = SummaryWriter(log_dir=tb_log_dir)

        # TODO: set model
        model = self.network(in_channels=1,
                             out_channels=2,
                             final_sigmoid=False)  # softmax
        logger.info(model)
        model = torch.nn.DataParallel(model,
                                      device_ids=self.config['gpus']).cuda()

        # TODO: set data loader
        train_dataset = self.dataset(self.width,
                                     self.height,
                                     self.depth,
                                     self.config['data_dir'] + "\\train",
                                     self.config['data_dir'] + "\\label",
                                     aug=[])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True)

        best_perf = 0.0
        begin_epoch = 0

        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

        # TODO: set optimizer
        # reference: https://pytorch.org/docs/stable/optim.html
        optimizer = optim.Adam(model.parameters(),
                               lr=self.config['lr'],
                               weight_decay=1e-4)

        if os.path.exists(checkpoint_file) and self.config['auto_resume']:
            logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            begin_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            last_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])

            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

        # TODO: set learning rate scheduler
        # reference: https://pytorch.org/docs/stable/optim.html
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

        for epoch in range(begin_epoch, self.config['epoch']):
            model.train()

            losses = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()

            end = time.time()
            for idx, batch in enumerate(train_loader):
                data_time.update(time.time() - end)

                image = batch['image'].type(torch.cuda.FloatTensor)
                label = batch['label'].type(torch.cuda.LongTensor)
                ''' class weight calculation
                true_class = np.round_(float(label.sum()) / label.reshape((-1)).size(0), 2)
                class_weights = torch.Tensor([true_class, 1 - true_class]).type(torch.cuda.FloatTensor)
                '''

                # TODO: set model, input, output and loss
                output = model(image)
                loss = self.loss(output, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss.item(), image.size(0))
                batch_time.update(time.time() - end)
                end = time.time()
                '''FOR DEBUG'''
                output_temp = one_hot_to_index(output.detach().cpu().numpy())
                save_image_to_nib(
                    output_temp[0].astype(np.uint8).transpose(1, 2, 0),
                    final_output_dir, 'res')

                # TODO: add validation stage

                msg = 'Epoch: [{0}][{1}/{2}]\t' \
                      'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                      'Speed {speed:.1f} samples/s\t' \
                      'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                      'Loss {loss.val:.5f}'.format(
                    epoch + 1, idx + 1, len(train_loader), batch_time=batch_time,
                    speed=image.size(0) / batch_time.val,
                    data_time=data_time,
                    loss=losses)
                logger.info(msg)

            # Step the LR scheduler once per epoch, after the optimizer
            # updates for the epoch have run.
            lr_scheduler.step()

            perf_indicator = 0  # TODO: set metric function
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
                best_model_state_file = os.path.join(final_output_dir,
                                                     'best_model.pth')
                logger.info('=> saving best model state to {}'.format(
                    best_model_state_file))
            else:
                best_model = False

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': self.config['description'],
                    'state_dict': model.state_dict(),
                    'best_state_dict': model.module.state_dict(),
                    'perf': perf_indicator,
                    'optimizer': optimizer.state_dict(),
                }, best_model, final_output_dir, model, epoch + 1)

            # TODO: can add other measure to tensorboard
            writer.add_scalar('loss', losses.avg, epoch + 1)

        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth')
        logger.info(
            '=> saving final model state to {}'.format(final_model_state_file))
        torch.save(model.module.state_dict(), final_model_state_file)
        writer.close()