예제 #1
0
def main(_run, _config, _log):
    """Train the few-shot segmentation model and write periodic snapshots.

    The model returns a query prediction, an alignment loss and a distance
    loss; the optimised objective is
    ``query_loss + dist_loss + 0.2 * align_loss``.

    Args:
        _run: Run object exposing ``observers``, ``experiment_info`` and
            ``log_scalar`` (presumably a Sacred ``Run`` -- confirm).
        _config: Mapping with the experiment configuration (seed, GPU id,
            dataset, paths, task and optimiser settings, intervals).
        _log: Logger used for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    # Snapshots can only be written when an observer provides a run
    # directory; without one, training still runs but nothing is saved.
    snapshot_dir = None
    if _run.observers:
        run_dir = _run.observers[0].dir
        snapshot_dir = f'{run_dir}/snapshots'
        os.makedirs(snapshot_dir, exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{run_dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
    else:
        _log.warning('No observer attached: model snapshots will not be saved.')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=_config['path'][data_name]['data_dir'],
        split=_config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=_config['task']['n_ways'],
        n_shots=_config['task']['n_shots'],
        n_queries=_config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    _log.info('###### Set optimizer ######')
    print(_config['mode'])
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0, 'dist_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move every support/query tensor to the GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)

        # Forward and Backward
        optimizer.zero_grad()

        # The U-2-Net saliency branch that used to produce `pred` is
        # disabled; the model receives an empty list in its place.
        pred = []

        query_pred, align_loss, dist_loss = model(support_images, support_fg_mask, support_bg_mask,
                                                  query_images, pred)
        query_loss = criterion(query_pred, query_labels)
        # The alignment weight is hard-coded to 0.2 here rather than read
        # from _config['align_loss_scaler'].
        loss = query_loss + dist_loss + align_loss * 0.2
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the model may return a plain 0 for the auxiliary losses).
        query_loss = query_loss.detach().data.cpu().numpy()
        dist_loss = dist_loss.detach().data.cpu().numpy() if dist_loss != 0 else 0
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        _run.log_scalar('dist_loss', dist_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['dist_loss'] += dist_loss

        # Print running averages; dist_loss is averaged like the other two
        # instead of showing only the most recent iteration's value.
        if (i_iter + 1) % _config['print_interval'] == 0:
            n_done = i_iter + 1
            avg_loss = log_loss['loss'] / n_done
            avg_align_loss = log_loss['align_loss'] / n_done
            avg_dist_loss = log_loss['dist_loss'] / n_done
            print(f'step {i_iter+1}: loss: {avg_loss}, align_loss: {avg_align_loss}, '
                  f'dist_loss: {avg_dist_loss}')

        if snapshot_dir is not None and (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(model.state_dict(),
                       os.path.join(snapshot_dir, f'{i_iter + 1}.pth'))

    if snapshot_dir is not None:
        _log.info('###### Saving final model ######')
        torch.save(model.state_dict(),
                   os.path.join(snapshot_dir, f'{i_iter + 1}.pth'))
예제 #2
0
def main(config):
    """Train the few-shot segmentation model on the DAVIS few-shot dataset.

    Snapshots are written to ``config['snapshots']`` every
    ``config['save_pred_every']`` iterations and once more after the loop.
    The query prediction is resized to the batch's ``img_size`` with
    bilinear interpolation before the cross-entropy loss is computed.

    Args:
        config: Mapping with the experiment configuration. Keys read here:
            ``snapshots``, ``palette_dir``, ``seed``, ``model``, ``dataset``,
            ``label_sets``, ``input_size``, ``path``, ``n_steps``,
            ``batch_size``, ``task``, ``optim``, ``lr_milestones``,
            ``ignore_label``, ``align_loss_scaler``, ``print_interval`` and
            ``save_pred_every``.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    snap_shots_dir = config['snapshots']
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(snap_shots_dir, exist_ok=True)

    # Parse the palette file: whitespace-separated integers per line, 768
    # values in total. split() tolerates a final line without a trailing
    # newline and repeated spaces; blank lines are skipped.
    # NOTE: `palette` is only needed by the palettised prediction dump, which
    # is currently disabled, so it is not used further down.
    with open(config['palette_dir']) as f:
        palette_rows = [[int(p) for p in line.split()] for line in f if line.strip()]
    palette = list(np.asarray(palette_rows).reshape(768))

    # The device index is hard-coded; it must match DataParallel's device_ids.
    gpu_id = 2
    torch.cuda.set_device(gpu_id)
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[gpu_id])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=config['path'][data_name]['data_dir'],
        split=config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=config['n_steps'] * config['batch_size'],
        n_ways=config['task']['n_ways'],
        n_shots=config['task']['n_shots'],
        n_queries=config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move every support/query tensor to the GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        img_size = sample_batched['img_size']
        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)
        pre_masks = [query_mask.float().cuda() for query_mask in sample_batched['query_masks']]

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask, support_bg_mask,
                                       query_images, pre_masks)
        # Resize the prediction to the size reported by the batch before
        # comparing it with the labels.
        query_pred = F.interpolate(query_pred, size=img_size, mode="bilinear")
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the model may return a plain 0 for the alignment loss).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0

        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # Print running averages over all iterations so far.
        if (i_iter + 1) % config['print_interval'] == 0:
            avg_loss = log_loss['loss'] / (i_iter + 1)
            avg_align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {avg_loss}, align_loss: {avg_align_loss}')

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))

    # Final snapshot, named after the last completed iteration.
    torch.save(model.state_dict(),
               os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))
예제 #3
0
파일: train.py 프로젝트: ZHUXUHAN/panet
def main(_run, _config, _log):
    """Train the few-shot segmentation model on VOC or COCO episodes.

    The optimised objective is
    ``query_loss + align_loss * _config['align_loss_scaler']``. Per-iteration
    losses are sent to ``_run.log_scalar`` and running averages are printed
    every ``_config['print_interval']`` iterations.

    Args:
        _run: Run object exposing ``observers``, ``experiment_info`` and
            ``log_scalar`` (presumably a Sacred ``Run`` -- confirm).
        _config: Mapping with the experiment configuration (seed, GPU id,
            dataset, paths, task and optimiser settings, intervals).
        _log: Logger used for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    # Snapshots can only be written when an observer provides a run
    # directory; without one, training still runs but nothing is saved.
    snapshot_dir = None
    if _run.observers:
        run_dir = _run.observers[0].dir
        snapshot_dir = f'{run_dir}/snapshots'
        os.makedirs(snapshot_dir, exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{run_dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
    else:
        _log.warning('No observer attached: model snapshots will not be saved.')

    set_seed(_config['seed'])
    # cudnn.enabled = True
    # cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    # Step decay at the configured milestones. The settings dict for the
    # disabled Adjust_Learning_Rate warm-up scheduler was unused and has been
    # removed.
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move every support/query tensor to the GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat(
            [query_label.long().cuda()
             for query_label in sample_batched['query_labels']],
            dim=0)

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * _config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the model may return a plain 0 for the alignment loss).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = (align_loss.detach().data.cpu().numpy()
                      if align_loss != 0 else 0)
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # Print running averages over all iterations so far.
        if (i_iter + 1) % _config['print_interval'] == 0:
            avg_loss = log_loss['loss'] / (i_iter + 1)
            avg_align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {avg_loss}, align_loss: {avg_align_loss}')

        if snapshot_dir is not None and (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(model.state_dict(),
                       os.path.join(snapshot_dir, f'{i_iter + 1}.pth'))

    if snapshot_dir is not None:
        _log.info('###### Saving final model ######')
        torch.save(model.state_dict(),
                   os.path.join(snapshot_dir, f'{i_iter + 1}.pth'))
예제 #4
0
파일: train.py 프로젝트: yanjj16/PANet_at1
def main(config):
    """Train the few-shot segmentation model on the DAVIS few-shot dataset.

    Snapshots are written to ``config['snapshots']`` every
    ``config['save_pred_every']`` iterations and once more after the loop.
    Besides the support set and query images, the model also receives the
    batch's ``query_masks`` as an additional input.

    Args:
        config: Mapping with the experiment configuration. Keys read here:
            ``snapshots``, ``seed``, ``model``, ``dataset``, ``label_sets``,
            ``input_size``, ``path``, ``n_steps``, ``batch_size``, ``task``,
            ``optim``, ``lr_milestones``, ``ignore_label``,
            ``align_loss_scaler``, ``print_interval`` and ``save_pred_every``.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    snap_shots_dir = config['snapshots']
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(snap_shots_dir, exist_ok=True)

    # The device index is hard-coded; it must match DataParallel's device_ids.
    gpu_id = 2
    torch.cuda.set_device(gpu_id)
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[gpu_id])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                        split=config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=config['n_steps'] * config['batch_size'],
                        n_ways=config['task']['n_ways'],
                        n_shots=config['task']['n_shots'],
                        n_queries=config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move every support/query tensor to the GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat(
            [query_label.long().cuda()
             for query_label in sample_batched['query_labels']],
            dim=0)
        pre_masks = [
            query_mask.float().cuda()
            for query_mask in sample_batched['query_masks']
        ]

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the model may return a plain 0 for the alignment loss).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = (align_loss.detach().data.cpu().numpy()
                      if align_loss != 0 else 0)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # Print running averages over all iterations so far.
        if (i_iter + 1) % config['print_interval'] == 0:
            avg_loss = log_loss['loss'] / (i_iter + 1)
            avg_align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {avg_loss}, align_loss: {avg_align_loss}')

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))

    # Final snapshot, named after the last completed iteration.
    torch.save(model.state_dict(),
               os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))
예제 #5
0
def main(_run, _config, _log):
    """Train the few-shot segmentation model on the 'BCV' dataset.

    Builds the model, data loader, SGD optimiser and step scheduler from
    ``_config``, then iterates over the training loader. Running-average
    losses are printed every ``_config['print_interval']`` iterations; at the
    same interval the model weights are written to ``snapshots/last.pth`` in
    the first observer's directory and, when ``_config['record']`` is truthy,
    an overlay image is sent to a ``SummaryWriter``.

    Args:
        _run: Run object exposing ``observers``, ``experiment_info`` and
            ``log_scalar`` (presumably a Sacred ``Run`` -- confirm).
        _config: Mapping with the experiment configuration.
        _log: Logger used for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is not ``'BCV'``.
    """
    # When an observer is attached, create the snapshot directory and copy the
    # experiment's source files into the run directory.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'BCV':
        make_data = meta_data
    else:
        print(f"data name : {data_name}")
        raise ValueError('Wrong config for dataset!')

    # Only the training split is used below; the validation and test splits
    # are returned by make_data but not referenced in the visible code.
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  # pinned host memory disabled (original note: True loads data while training on GPU)
        drop_last=True)
    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    if _config['record']:  ## tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    log_loss = {'loss': 0, 'align_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        s_x_orig = sample_batched['s_x'].cuda(
        )  # [B, Support, slice_num=1, 1, 256, 256]
        s_x = s_x_orig.squeeze(2)  # [B, Support, 1, 256, 256]
        s_y_fg_orig = sample_batched['s_y'].cuda(
        )  # [B, Support, slice_num, 1, 256, 256]
        s_y_fg = s_y_fg_orig.squeeze(2)  # [B, Support, 1, 256, 256]
        s_y_fg = s_y_fg.squeeze(2)  # [B, Support, 256, 256]
        # Background mask is the complement of the foreground mask.
        s_y_bg = torch.ones_like(s_y_fg) - s_y_fg
        q_x_orig = sample_batched['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
        q_x = q_x_orig.squeeze(1)  # [B, 1, 256, 256]
        q_y_orig = sample_batched['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
        q_y = q_y_orig.squeeze(1)  # [B, 1, 256, 256]
        q_y = q_y.squeeze(1).long()  # [B, 256, 256]
        # Repack as a single "way": one outer list holding one tensor per shot.
        s_xs = [[s_x[:, shot, ...] for shot in range(_config["n_shot"])]]
        s_y_fgs = [[s_y_fg[:, shot, ...] for shot in range(_config["n_shot"])]]
        s_y_bgs = [[s_y_bg[:, shot, ...] for shot in range(_config["n_shot"])]]
        q_xs = [q_x]
        """
        Args:
            supp_imgs: support images
                way x shot x [B x 1 x H x W], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x H x W], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x H x W], list of lists of tensors
            qry_imgs: query images
                N x [B x 1 x H x W], list of tensors
            qry_pred: [B, 2, H, W]
        """

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(s_xs, s_y_fgs, s_y_bgs,
                                       q_xs)  #[B, 2, w, h]
        query_loss = criterion(query_pred, q_y)
        loss = query_loss + align_loss * _config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the alignment loss is only converted when it is non-zero)
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy(
        ) if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running averages over all iterations so far.
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}')

            if _config['record']:
                # Visualise the first sample of the batch: query image,
                # arg-max prediction and ground-truth label.
                batch_i = 0
                frames = []
                query_pred = query_pred.argmax(dim=1)
                query_pred = query_pred.unsqueeze(1)
                frames += overlay_color(q_x_orig[batch_i, 0],
                                        query_pred[batch_i].float(),
                                        q_y_orig[batch_i, 0])
                visual = make_grid(frames, normalize=True, nrow=2)
                writer.add_image("train/visual", visual, i_iter)

            # NOTE(review): observers[0] is indexed here without the
            # `if _run.observers` guard used at the top of the function.
            print(f"train - iter:{i_iter} \t => model saved", end='\n')
            save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
            torch.save(model.state_dict(), save_fname)