Example #1
    def test_add_custom_scalars(self):
        with TemporaryDirectory() as tmp_dir:
            writer = SummaryWriter(tmp_dir)
            writer.add_custom_scalars = MagicMock()
            with summary_writer_context(writer):
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["a", "b"], category="cat", title="title"
                )
                with self.assertRaisesRegex(
                    AssertionError, r"Title \(title\) is already in category \(cat\)"
                ):
                    SummaryWriterContext.add_custom_scalars_multilinechart(
                        ["c", "d"], category="cat", title="title"
                    )
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["e", "f"], category="cat", title="title2"
                )
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["g", "h"], category="cat2", title="title"
                )

            SummaryWriterContext.add_custom_scalars(writer)
            writer.add_custom_scalars.assert_called_once_with(
                {
                    "cat": {
                        "title": ["Multiline", ["a", "b"]],
                        "title2": ["Multiline", ["e", "f"]],
                    },
                    "cat2": {"title": ["Multiline", ["g", "h"]]},
                }
            )
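The test relies on `summary_writer_context` to install the writer as the process-wide default that `SummaryWriterContext` dispatches to. A minimal sketch of such a context manager, assuming a module-level default writer (the real helper ships with the project under test; this stand-in is hypothetical):

from contextlib import contextmanager

_DEFAULT_WRITER = None  # hypothetical module-level default writer


@contextmanager
def summary_writer_context(writer):
    # Swap the default writer in for the duration of the block.
    global _DEFAULT_WRITER
    previous, _DEFAULT_WRITER = _DEFAULT_WRITER, writer
    try:
        yield writer
    finally:
        _DEFAULT_WRITER = previous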
Example #2
    def test_add_custom_scalars(self):
        with TemporaryDirectory() as tmp_dir:
            writer = SummaryWriter(tmp_dir)
            writer.add_custom_scalars = MagicMock()
            with summary_writer_context(writer):
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["a", "b"], category="cat", title="title")
                with self.assertRaisesRegex(
                        AssertionError,
                        r"Title \(title\) is already in category \(cat\)"):
                    SummaryWriterContext.add_custom_scalars_multilinechart(
                        ["c", "d"], category="cat", title="title")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["e", "f"], category="cat", title="title2")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["g", "h"], category="cat2", title="title")

            SummaryWriterContext.add_custom_scalars(writer)
            writer.add_custom_scalars.assert_called_once_with({
                "cat": {
                    "title": ["Multiline", ["a", "b"]],
                    "title2": ["Multiline", ["e", "f"]],
                },
                "cat2": {
                    "title": ["Multiline", ["g", "h"]]
                },
            })
Example #3
def add_custom_scalars(writer: SummaryWriter):
    writer.add_custom_scalars({
        "Predicate classification": {
            "BCE": ["MultiLine", ["(train|val)_gt/loss/bce"]],
            "Rank": ["MultiLine", ["(train|val)_gt/loss/rank"]],
            "Recall@5": ["MultiLine", ["(train|val)_gt/pc/recall_at_5"]],
            "Recall@10": ["MultiLine", ["(train|val)_gt/pc/recall_at_10"]],
            "Mean Average Precision": ["MultiLine", ["(train|val)_gt/pc/mAP"]],
        },
        "Visual relations detection metrics": {
            "Predicate Recall@50": [
                "MultiLine",
                ["(val|test)_gt/vr/predicate/recall_at_50"],
            ],
            "Predicate Recall@100": [
                "MultiLine",
                ["(val|test)_gt/vr/predicate/recall_at_100"],
            ],
            "Phrase Recall@50": [
                "MultiLine",
                ["(val|test)_d2/vr/phrase/recall_at_50"],
            ],
            "Phrase Recall@100": [
                "MultiLine",
                ["(val|test)_d2/vr/phrase/recall_at_100"],
            ],
            "Relationship Recall@50": [
                "MultiLine",
                ["(val|test)_d2/vr/relationship/recall_at_50"],
            ],
            "Relationship Recall@100": [
                "MultiLine",
                ["(val|test)_d2/vr/relationship/recall_at_100"],
            ],
        },
        "Others": {
            "GPU (MB)": ["MultiLine", ["(train|val|test)_(gt|vr)/gpu_mb"]]
        },
    })
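The chart definitions above are regex-based: a chart such as `(train|val)_gt/loss/bce` stays empty until scalars with matching tags are logged. A hypothetical usage sketch of the function above (log directory and values are placeholders):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/example")  # placeholder log directory
add_custom_scalars(writer)  # register the layout once

for step in range(100):
    # Tags must match the layout's regexes for the charts to populate.
    writer.add_scalar("train_gt/loss/bce", 1.0 / (step + 1), step)
    writer.add_scalar("val_gt/loss/bce", 1.5 / (step + 1), step)
writer.close()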
Example #4
def main():

    resume = True
    path = 'data/NYU_DEPTH'
    batch_size = 16
    epochs = 10000
    device = torch.device('cuda:0')
    print_every = 5
    # exp_name = 'resnet18_nodropout_new'
    exp_name = 'only_depth'
    # exp_name = 'normal_internel'
    # exp_name = 'sep'
    lr = 1e-5
    weight_decay = 0.0005
    log_dir = os.path.join('logs', exp_name)
    model_dir = os.path.join('checkpoints', exp_name)
    val_every = 16
    save_every = 16


    # tensorboard
    # remove old logs if not resuming
    if not resume:
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
            os.makedirs(log_dir)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    tb = SummaryWriter(log_dir)
    tb.add_custom_scalars({
        'metrics': {
            'thres_1.25': ['Multiline', ['thres_1.25/train', 'thres_1.25/test']],
            'thres_1.25_2': ['Multiline', ['thres_1.25_2/train', 'thres_1.25_2/test']],
            'thres_1.25_3': ['Multiline', ['thres_1.25_3/train', 'thres_1.25_3/test']],
            'ard': ['Multiline', ['ard/train', 'ard/test']],
            'srd': ['Multiline', ['srd/train', 'srd/test']],
            'rmse_linear': ['Multiline', ['rmse_linear/train', 'rmse_linear/test']],
            'rmse_log': ['Multiline', ['rmse_log/train', 'rmse_log/test']],
            'rmse_log_invariant': ['Multiline', ['rmse_log_invariant/train', 'rmse_log_invariant/test']],
        }
    })
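    # Note: these combined charts appear under TensorBoard's CUSTOM SCALARS
    # tab and only populate once matching tags are logged, e.g.
    # tb.add_scalar('thres_1.25/train', value, global_step).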
    
    
    # data loader
    dataset = NYUDepth(path, 'train')
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=4)
    
    dataset_test = NYUDepth(path, 'test')
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True, num_workers=4)
    
    
    # load model
    model = FCRN(True)
    model = model.to(device)
    
    
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    start_epoch = 0
    if resume:
        model_path = os.path.join(model_dir, 'model.pth')
        if os.path.exists(model_path):
            print('Loading checkpoint from {}...'.format(model_path))
            # load model and optimizer
            checkpoint = torch.load(os.path.join(model_dir, 'model.pth'), map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            print('Model loaded.')
        else:
            print('No checkpoint found. Train from scratch')
    
    # training
    metric_logger = MetricLogger()
    
    end = time.perf_counter()
    max_iters = epochs * len(dataloader)
    
    def normal_loss(pred, normal, conf):
        """
        Cosine-similarity loss between predicted and ground-truth normals,
        weighted by per-pixel confidence.

        :param pred: (B, 3, H, W)
        :param normal: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        dot_prod = (pred * normal).sum(dim=1)
        # weighted loss, (B, )
        batch_loss = ((1 - dot_prod) * conf[:, 0]).sum(1).sum(1)
        # normalize by total confidence, to (B, )
        batch_loss /= conf[:, 0].sum(1).sum(1)
        return batch_loss.mean()

    def consistency_loss(pred, cloud, normal, conf):
        """
        :param pred: (B, 1, H, W)
        :param normal: (B, 3, H, W)
        :param cloud: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        B, _, _, _ = normal.size()
        normal = normal.detach()
        cloud = cloud.clone()
        cloud[:, 2:3, :, :] = pred
        # algorithm: use a kernel
        kernel = torch.ones((1, 1, 7, 7), device=pred.device)
        kernel = -kernel
        kernel[0, 0, 3, 3] = 48
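        # The kernel is -1 everywhere with 48 at the center, so its 49
        # entries sum to zero: the convolution computes 48 * (center value
        # minus the neighborhood mean) per channel, a discrete difference
        # whose direction is compared against the predicted normals below.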
    
        cloud_0 = cloud[:, 0:1]
        cloud_1 = cloud[:, 1:2]
        cloud_2 = cloud[:, 2:3]
        diff_0 = F.conv2d(cloud_0, kernel, padding=6, dilation=2)
        diff_1 = F.conv2d(cloud_1, kernel, padding=6, dilation=2)
        diff_2 = F.conv2d(cloud_2, kernel, padding=6, dilation=2)
        # (B, 3, H, W)
        diff = torch.cat((diff_0, diff_1, diff_2), dim=1)
        # normalize
        diff = F.normalize(diff, dim=1)
        # (B, 1, H, W)
        dot_prod = (diff * normal).sum(dim=1, keepdim=True)
        # weighted mean over image
        dot_prod = torch.abs(dot_prod.view(B, -1))
        conf = conf.view(B, -1)
        loss = (dot_prod * conf).sum(1) / conf.sum(1)
        # mean over batch
        return loss.mean()
    
    def criterion(depth_pred, normal_pred, depth, normal, cloud, conf):
        mse_loss = F.mse_loss(depth_pred, depth)
        consis_loss = consistency_loss(depth_pred, cloud, normal_pred, conf)
        norm_loss = normal_loss(normal_pred, normal, conf)
        # Debugging override: zero out the consistency loss and train on
        # MSE only; the commented returns restore the full objectives.
        consis_loss = torch.zeros_like(norm_loss)

        return mse_loss, mse_loss, mse_loss
        # return mse_loss, consis_loss, norm_loss
        # return norm_loss, norm_loss, norm_loss
    
    print('Start training')
    for epoch in range(start_epoch, epochs):
        # train
        model.train()
        for i, data in enumerate(dataloader):
            start = end
            i += 1
            data = [x.to(device) for x in data]
            image, depth, normal, conf, cloud = data
            depth_pred, normal_pred = model(image)
            mse_loss, consis_loss, norm_loss = criterion(depth_pred, normal_pred, depth, normal, cloud, conf)
            loss = mse_loss + consis_loss + norm_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # bookkeeping
            end = time.perf_counter()
            metric_logger.update(loss=loss.item())
            metric_logger.update(mse_loss=mse_loss.item())
            metric_logger.update(norm_loss=norm_loss.item())
            metric_logger.update(consis_loss=consis_loss.item())
            metric_logger.update(batch_time=end-start)

            
            if i % print_every == 0:
                # Compute eta. global step: starting from 1
                global_step = epoch * len(dataloader) + i
                seconds = (max_iters - global_step) * metric_logger['batch_time'].global_avg
                eta = datetime.timedelta(seconds=int(seconds))
                # to display: eta, epoch, iteration, loss, batch_time
                display_dict = {
                    'eta': eta,
                    'epoch': epoch,
                    'iter': i,
                    'loss': metric_logger['loss'].median,
                    'batch_time': metric_logger['batch_time'].median
                }
                display_str = [
                    'eta: {eta}s',
                    'epoch: {epoch}',
                    'iter: {iter}',
                    'loss: {loss:.4f}',
                    'batch_time: {batch_time:.4f}s',
                ]
                print(', '.join(display_str).format(**display_dict))
                
                # tensorboard
                min_depth = depth[0].min()
                max_depth = depth[0].max() * 1.25
                depth = (depth[0] - min_depth) / (max_depth - min_depth)
                depth_pred = (depth_pred[0] - min_depth) / (max_depth - min_depth)
                depth_pred = torch.clamp(depth_pred, min=0.0, max=1.0)
                normal = (normal[0] + 1) / 2
                normal_pred = (normal_pred[0] + 1) / 2
                conf = conf[0]
                
                tb.add_scalar('train/loss', metric_logger['loss'].median, global_step)
                tb.add_scalar('train/mse_loss', metric_logger['mse_loss'].median, global_step)
                tb.add_scalar('train/consis_loss', metric_logger['consis_loss'].median, global_step)
                tb.add_scalar('train/norm_loss', metric_logger['norm_loss'].median, global_step)
                
                tb.add_image('train/depth', depth, global_step)
                tb.add_image('train/normal', normal, global_step)
                tb.add_image('train/depth_pred', depth_pred, global_step)
                tb.add_image('train/normal_pred', normal_pred, global_step)
                tb.add_image('train/conf', conf, global_step)
                tb.add_image('train/image', image[0], global_step)
                
        if epoch % val_every == 0 and epoch != 0:
            # validate after each epoch
            validate(dataloader, model, device, tb, epoch, 'train')
            validate(dataloader_test, model, device, tb, epoch, 'test')
        if epoch % save_every == 0 and epoch != 0:
            to_save = {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch,
            }
            torch.save(to_save, os.path.join(model_dir, 'model.pth'))
Example #5
        if i > 0 and i % 20 == 0:
            #             logger.info('[%d, %5d] loss: %.3f' %
            #                   (epoch + 1, i + 1, running_loss / 2000))
            plot(epoch * len(trainloader) + i, running_loss, 'Train Loss')
            running_loss = 0.0


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PPDL')
    parser.add_argument('--params', dest='params', default='utils/params.yaml')
    parser.add_argument('--name', dest='name', required=True)

    args = parser.parse_args()
    d = datetime.now().strftime('%b.%d_%H.%M.%S')
    writer = SummaryWriter(log_dir=f'runs/{args.name}')
    writer.add_custom_scalars(layout)

    with open(args.params) as f:
        params = yaml.safe_load(f)
    if params.get('model', False) == 'word':
        helper = TextHelper(current_time=d, params=params, name='text')

        helper.corpus = torch.load(helper.params['corpus'])
        logger.info(helper.corpus.train.shape)
    else:
        helper = ImageHelper(current_time=d, params=params, name='utk')
    logger.addHandler(
        logging.FileHandler(filename=f'{helper.folder_path}/log.txt'))
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.DEBUG)
    logger.info(f'current path: {helper.folder_path}')
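`layout` is not defined in this snippet and must be built before `add_custom_scalars` is called. A hypothetical placeholder, purely for illustration (categories and tags invented):

# Hypothetical stand-in for the undefined `layout`; the real one is
# defined elsewhere in the project. Tag names here are invented.
layout = {
    'training': {
        'loss': ['Multiline', ['loss/train', 'loss/test']],
    },
}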
Example #6
def main(args):
    config_module = importlib.import_module(args.config)
    config = config_module.Config()

    # Load experiment setting
    config.initialize(args)
    max_iter = config.max_iter

    # Dataloader

    train_content_loader = get_dataloader(config, 'train')
    train_class_loader = get_dataloader(config, 'train')
    test_content_loader = get_dataloader(config, 'test')
    test_class_loader = get_dataloader(config, 'test', shuffle=True)
    trainfull_content_loader = get_dataloader(config,
                                              'trainfull',
                                              shuffle=True)
    trainfull_class_loader = get_dataloader(config, 'trainfull', shuffle=True)
    test_rec_loader = get_dataloader(config, 'test', shuffle=True)
    rec_loader = cycle(test_rec_loader)

    # Trainer
    trainer = Trainer(config)
    print("here!")

    tr_info = open(os.path.join(config.info_dir, "info-network"), "w")
    print(trainer.model, file=tr_info)
    tr_info.close()

    trainer.to(config.device)
    iterations = trainer.resume()

    # Summary Writer
    train_writer = SummaryWriter(os.path.join(config.tb_dir, 'train'))
    test_writer = SummaryWriter(os.path.join(config.tb_dir, 'test'))

    layout = {
        'adversarial acc & loss': {
            'acc': ['Multiline', ['gen_acc_all', 'dis_acc_all']],
            'adv_loss': ['Multiline', ['gen_loss_adv', 'dis_loss_adv_all']]
        },
        'reconstruction loss': {
            'gen_loss_recon_all': ['Multiline', ['gen_loss_recon_all']],
            'gen_loss_recon_r': ['Multiline', ['gen_loss_recon_r']],
            'gen_loss_recon_s': ['Multiline', ['gen_loss_recon_s']],
            'gen_loss_recon_u': ['Multiline', ['gen_loss_recon_u']]
        }
    }
    train_writer.add_custom_scalars(layout)

    it = iterations
    cyc_train_content_loader = cycle(train_content_loader)
    cyc_train_class_loader = cycle(train_class_loader)

    while True:
        it = it + 1
        co_data = next(cyc_train_content_loader)
        cl_data = next(cyc_train_class_loader)

        d_acc = trainer.dis_update(co_data, cl_data)
        g_acc = trainer.gen_update(co_data, cl_data)

        if (iterations + 1) % config.log_freq == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)
            rec_data = next(rec_loader)
            loss_dict, _ = trainer.test_rec(rec_data)
            for key, value in loss_dict.items():
                test_writer.add_scalar(key, value, iterations + 1)

        if ((iterations + 1) % config.mt_save_iter == 0
                or (iterations + 1) % config.mt_display_iter == 0):
            if (iterations + 1) % config.mt_save_iter == 0:
                key_str = '%08d' % (iterations + 1)
            else:
                key_str = 'current'

            with torch.no_grad():
                """latent codes"""  # !!!!! TD: add a separate function, merge with plot_clusters

                vis_dicts = {}
                for phase, co_loader, cl_loader, writer in [[
                        'train', train_content_loader, train_class_loader,
                        train_writer
                ], [
                        'test', test_content_loader, test_class_loader,
                        test_writer
                ]]:

                    vis_dict = None
                    for t, tcl_data in enumerate(cl_loader):
                        vis_codes = trainer.get_latent_codes(tcl_data)
                        if vis_dict is None:
                            vis_dict = {}
                            for key, value in vis_codes.items():
                                vis_dict[key] = [value]
                        else:
                            for key, value in vis_codes.items():
                                vis_dict[key].append(value)
                    for key, value in vis_dict.items():
                        if phase == "test" and key == "content_code":
                            continue
                        if key == "meta":
                            secondary_keys = value[0].keys()
                            num = len(value)
                            vis_dict[key] = {
                                secondary_key: [
                                    to_float(item) for i in range(num)
                                    for item in value[i][secondary_key]
                                ]
                                for secondary_key in secondary_keys
                            }
                        else:
                            vis_dict[key] = torch.cat(vis_dict[key], 0)
                            vis_dict[key] = vis_dict[key].cpu().numpy()
                            vis_dict[key] = to_float(vis_dict[key].reshape(
                                vis_dict[key].shape[0], -1))

                    vis_dicts[phase] = vis_dict

                writers = {"train": train_writer, "test": test_writer}
                get_all_plots(vis_dicts,
                              os.path.join(config.output_dir, key_str),
                              writers, iterations + 1)
                """outputs"""
                for phase, co_loader, cl_loader in [[
                        'trainfull', trainfull_content_loader,
                        trainfull_class_loader
                ], ['test', test_content_loader, test_class_loader]]:
                    for status in ["3d", "2d"]:
                        name = "%s_%s_%s" % (phase, key_str, status)
                        outputs = {}
                        for t, (tco_data, tcl_data) in enumerate(
                                zip(co_loader, cl_loader)):
                            if t >= config.test_batch_n:
                                break
                            cur_outputs = trainer.test(tco_data, tcl_data,
                                                       status)
                            for key in cur_outputs.keys():
                                output = cur_outputs[key]
                                if key not in outputs:
                                    outputs[key] = []
                                if isinstance(output, torch.Tensor):
                                    outputs[key].append(
                                        output.reshape(output.shape[1:]))
                                else:
                                    outputs[key].append(output)

                        output_path = os.path.join(config.output_dir, name)
                        print("%s saved" % name)
                        torch.save(outputs, output_path)

        if (iterations + 1) % config.save_freq == 0:
            trainer.save(iterations)
            print('Saved model at iteration %d' % (iterations + 1))

        iterations += 1
        if iterations >= max_iter:
            print("Finish Training")
            sys.exit(0)
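The combined charts registered via `layout` above assume that `write_loss` emits scalars under exactly the tags the layout names. A hypothetical minimal version, for illustration only (the real `write_loss` ships with the project):

def write_loss(iterations, trainer, writer):
    # Hypothetical sketch: assumes the trainer exposes numeric members
    # whose names match the tags referenced in `layout`.
    for tag in ['gen_acc_all', 'dis_acc_all', 'gen_loss_adv',
                'dis_loss_adv_all', 'gen_loss_recon_all',
                'gen_loss_recon_r', 'gen_loss_recon_s', 'gen_loss_recon_u']:
        value = getattr(trainer, tag, None)
        if value is not None:
            writer.add_scalar(tag, value, iterations + 1)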
Example #7
class Summarizer(object):
    def __init__(self):
        # When report is False, every add_* method below is a no-op.
        self.report = False
        self.global_step = None
        self.writer = None

    def initialize_writer(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_scalar(tag,
                               scalar_value,
                               global_step=global_step,
                               walltime=walltime)

    def add_scalars(self,
                    main_tag,
                    tag_scalar_dict,
                    global_step=None,
                    walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_scalars(main_tag,
                                tag_scalar_dict,
                                global_step=global_step,
                                walltime=walltime)

    def add_histogram(self,
                      tag,
                      values,
                      global_step=None,
                      bins='tensorflow',
                      walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        if isinstance(values, chainer.cuda.cupy.ndarray):
            values = chainer.cuda.to_cpu(values)

        self.writer.add_histogram(tag,
                                  values,
                                  global_step=global_step,
                                  bins=bins,
                                  walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_image(tag,
                              img_tensor,
                              global_step=global_step,
                              walltime=walltime)

    def add_image_with_boxes(self,
                             tag,
                             img_tensor,
                             box_tensor,
                             global_step=None,
                             walltime=None,
                             **kwargs):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_image_with_boxes(tag,
                                         img_tensor,
                                         box_tensor,
                                         global_step=global_step,
                                         walltime=walltime,
                                         **kwargs)

    def add_figure(self,
                   tag,
                   figure,
                   global_step=None,
                   close=True,
                   walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_figure(tag,
                               figure,
                               global_step=global_step,
                               close=close,
                               walltime=walltime)

    def add_video(self,
                  tag,
                  vid_tensor,
                  global_step=None,
                  fps=4,
                  walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_video(tag,
                              vid_tensor,
                              global_step=global_step,
                              fps=fps,
                              walltime=walltime)

    def add_audio(self,
                  tag,
                  snd_tensor,
                  global_step=None,
                  sample_rate=44100,
                  walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_audio(tag,
                              snd_tensor,
                              global_step=global_step,
                              sample_rate=sample_rate,
                              walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_text(tag,
                             text_string,
                             global_step=global_step,
                             walltime=walltime)

    def add_graph_onnx(self, prototxt):
        if not self.report:
            return

        self.writer.add_graph_onnx(prototxt)

    def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
        if not self.report:
            return

        self.writer.add_graph(model,
                              input_to_model=input_to_model,
                              verbose=verbose,
                              **kwargs)

    def add_embedding(self,
                      mat,
                      metadata=None,
                      label_img=None,
                      global_step=None,
                      tag='default',
                      metadata_header=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_embedding(mat,
                                  metadata=metadata,
                                  label_img=label_img,
                                  global_step=global_step,
                                  tag=tag,
                                  metadata_header=metadata_header)

    def add_pr_curve(self,
                     tag,
                     labels,
                     predictions,
                     global_step=None,
                     num_thresholds=127,
                     weights=None,
                     walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_pr_curve(tag,
                                 labels,
                                 predictions,
                                 global_step=global_step,
                                 num_thresholds=num_thresholds,
                                 weights=weights,
                                 walltime=walltime)

    def add_pr_curve_raw(self,
                         tag,
                         true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall,
                         global_step=None,
                         num_thresholds=127,
                         weights=None,
                         walltime=None):
        if not self.report:
            return

        if global_step is None and self.global_step is not None:
            global_step = self.global_step

        self.writer.add_pr_curve_raw(tag,
                                     true_positive_counts,
                                     false_positive_counts,
                                     true_negative_counts,
                                     false_negative_counts,
                                     precision,
                                     recall,
                                     global_step=global_step,
                                     num_thresholds=num_thresholds,
                                     weights=weights,
                                     walltime=walltime)

    def add_custom_scalars_multilinechart(self,
                                          tags,
                                          category='default',
                                          title='untitled'):
        if not self.report:
            return
        self.writer.add_custom_scalars_multilinechart(tags,
                                                      category=category,
                                                      title=title)

    def add_custom_scalars_marginchart(self,
                                       tags,
                                       category='default',
                                       title='untitled'):
        if not self.report:
            return
        self.writer.add_custom_scalars_marginchart(tags,
                                                   category=category,
                                                   title=title)

    def add_custom_scalars(self, layout):
        if not self.report:
            return
        self.writer.add_custom_scalars(layout)
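A minimal usage sketch of the Summarizer above (the log directory and tags are placeholders):

summarizer = Summarizer()
summarizer.initialize_writer('logs/experiment')  # placeholder path
summarizer.report = True    # reporting is off by default
summarizer.global_step = 0  # fallback step when none is passed explicitly

summarizer.add_scalar('train/loss', 0.5)
summarizer.add_custom_scalars_multilinechart(
    ['train/loss', 'val/loss'], category='losses', title='loss')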