def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
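    # num_replica counts the data-parallel workers (Horovod world size, or the
    # number of local GPUs under DataParallel); it scales the reported batch size.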
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()
        
    # create optimizer
    params_lr = train_model.get_params_lr()
    # We use a custom optimizer and pass lr=1 so that different weights can use different learning rates
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
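        # DistributedOptimizer wraps the base optimizer so gradients are
        # averaged across Horovod workers via allreduce during backward.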
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loader
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'), flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'), flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
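        # DistributedSampler gives each Horovod rank its own disjoint shard of the data.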
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, sampler=train_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, sampler=val_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
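    # Speedometer reports training throughput for the global batch every display_iter iterations.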
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss'),])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss'),])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss'), ])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss'), ])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss'), ])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss'), ])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))

        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
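            # Re-seed the distributed sampler so each pass draws a fresh shuffle on every rank.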
            train_sampler.set_epoch(curr_iter)

            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()


            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

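                # Compute this iteration's learning rate from the schedule; the
                # custom SGD consumes it directly through step(lr) below.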
                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

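                # Average the losses across workers asynchronously; hvd.synchronize
                # below blocks until each allreduce has finished.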
                losses = []
                losses.append(allreduce_async(loss, name='train_total_loss'))
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1


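                # Shrink the momentum buffers when the learning rate decays so
                # stale momentum does not overshoot after the drop.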
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)

                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))
        else:
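            # Without Horovod, pull one batch per GPU and feed them jointly through DataParallel.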
            inner_iter = 0
            train_iterator = iter(train_loader)
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1
                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(*batch)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)
                
                losses = []
                losses.append(loss.item())
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)


                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))

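            # Exhaust the iterator so leftover batches and loader workers are cleaned up before the next pass.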
            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()

            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)

                    with torch.no_grad():
                        output = train_model(data, label)

                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)

            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1

                    with torch.no_grad():
                        output = train_model(*batch)

                    losses = []
                    for l in metrics_name:
                        losses.append(output[l].mean().item())

                    for metric, loss in zip(metrics, losses):
                        if is_master:
                            metric.update(_, _, loss)

                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))

            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                    metric.reset()
            if is_master:
                logger.info(s)

    if is_master and config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
    elif not config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
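

# The lr=1 trick above relies on a customized SGD whose step() accepts the
# scheduled learning rate. A minimal sketch of that idea, assuming each param
# group built by get_params_lr() carries its own relative learning rate (the
# names ScaledLrSGD and 'relative_lr' are illustrative, not the project's
# actual optimizer):
import torch


class ScaledLrSGD(torch.optim.SGD):
    """SGD whose step(lr) rescales every param group by the scheduled lr."""

    def step(self, lr, closure=None):
        for group in self.param_groups:
            # 'relative_lr' is a hypothetical per-group field; the effective
            # learning rate becomes relative_lr * scheduled lr.
            group['lr'] = group.get('relative_lr', 1.0) * lr
        return super().step(closure)

# Usage sketch: optimizer = ScaledLrSGD(params_lr, lr=1, momentum=0.9); optimizer.step(current_lr)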


# Example 2
def main():
    """Training for softmax classifier only.
  """
    # Retrieve experiment configurations.
    args = parse_args('Training for softmax classifier only.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = ListTagClassifierDataset(
        data_dir=args.data_dir,
        data_list=args.data_list,
        img_mean=config.network.pixel_means,
        img_std=config.network.pixel_stds,
        size=config.train.crop_size,
        random_crop=config.train.random_crop,
        random_scale=config.train.random_scale,
        random_mirror=config.train.random_mirror,
        random_grayscale=True,
        random_blur=True,
        training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    elif config.network.backbone_types == 'panoptic_deeplab_101':
        embedding_model = resnet_101_deeplab(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' + config.network.backbone_types)

    if config.network.prediction_types == 'softmax_classifier':
        prediction_model = softmax_classifier(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' + config.network.prediction_types)

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(
            config.network.pretrained)['embedding_model'],
                                        resume=True)
    else:
        raise ValueError('Pre-trained model is required.')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)

    embedding_model.eval()
    prediction_model.train()
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}

    # start training
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Check whether enough data remains to iterate through;
        # otherwise, re-initialize the data iterator.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
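        # The backbone stays frozen: embeddings are computed under no_grad and
        # only the softmax classifier below receives gradients.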
        with torch.no_grad():
            embeddings = embedding_model(*zip(image_batch, label_batch))

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in ['sem_ann_loss']:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)


# Example 3
def main():
    """Training for pixel-wise embeddings by pixel-segment
  contrastive learning loss for DensePose.
  """
    # Retrieve experiment configurations.
    args = parse_args('Training for pixel-wise embeddings for DensePose.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir,
                                           'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = DenseposeDataset(data_dir=args.data_dir,
                                     data_list=args.data_list,
                                     img_mean=config.network.pixel_means,
                                     img_std=config.network.pixel_stds,
                                     size=config.train.crop_size,
                                     random_crop=config.train.random_crop,
                                     random_scale=config.train.random_scale,
                                     random_mirror=config.train.random_mirror,
                                     training=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    else:
        raise ValueError('Unsupported backbone type: ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config).cuda()
    else:
        raise ValueError('Unsupported prediction type: ' + config.network.prediction_types)

    # Use synchronized batchnorm.
    if config.network.use_syncbn:
        embedding_model = convert_model(embedding_model).cuda()
        prediction_model = convert_model(prediction_model).cuda()

    # Use customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(embedding_model.get_params_lr() +
                    prediction_model.get_params_lr(),
                    lr=1,
                    momentum=config.train.momentum,
                    weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.train.resume:
        model_path = model_path_template.format(curr_iter)
        print('Resume training from {:s}'.format(model_path))
        embedding_model.load_state_dict(
            torch.load(model_path)['embedding_model'], resume=True)
        prediction_model.load_state_dict(
            torch.load(model_path)['prediction_model'], resume=True)
        optimizer.load_state_dict(
            torch.load(optimizer_path_template.format(curr_iter)))
    elif config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(
            config.network.pretrained))
        embedding_model.load_state_dict(torch.load(config.network.pretrained))
    else:
        print('Training from scratch')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)
    if config.network.use_syncbn:
        patch_replication_callback(embedding_model)
        patch_replication_callback(prediction_model)

    for module in embedding_model.modules():
        if isinstance(module, _BatchNorm) or isinstance(module, _ConvNd):
            print(module.training, module)
    print(embedding_model)
    print(prediction_model)

    # Create memory bank.
    memory_banks = {}

    # start training
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Check whether enough data remains to iterate through;
        # otherwise, re-initialize the data iterator.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clustering and prototypes.
        embeddings = embedding_model(*zip(image_batch, label_batch))

        # Synchronize cluster indices and compute prototypes.
        c_inds = [emb['cluster_index'] for emb in embeddings]
        cb_inds = [emb['cluster_batch_index'] for emb in embeddings]
        cs_labs = [emb['cluster_semantic_label'] for emb in embeddings]
        ci_labs = [emb['cluster_instance_label'] for emb in embeddings]
        c_embs = [emb['cluster_embedding'] for emb in embeddings]
        c_embs_with_loc = [
            emb['cluster_embedding_with_loc'] for emb in embeddings
        ]
        (prototypes, prototypes_with_loc, prototype_semantic_labels,
         prototype_instance_labels, prototype_batch_indices,
         cluster_indices) = (
             model_utils.gather_clustering_and_update_prototypes(
                 c_embs, c_embs_with_loc, c_inds, cb_inds, cs_labs, ci_labs,
                 'cuda:{:d}'.format(num_gpus - 1)))

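        # Attach the gathered prototypes to each GPU's label dict so the
        # prediction head can compute its losses locally.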
        for i in range(len(label_batch)):
            label_batch[i]['prototype'] = prototypes[i]
            label_batch[i]['prototype_with_loc'] = prototypes_with_loc[i]
            label_batch[i][
                'prototype_semantic_label'] = prototype_semantic_labels[i]
            label_batch[i][
                'prototype_instance_label'] = prototype_instance_labels[i]
            label_batch[i]['prototype_batch_index'] = prototype_batch_indices[
                i]
            embeddings[i]['cluster_index'] = cluster_indices[i]

        #semantic_tags = model_utils.gather_and_update_datas(
        #    [lab['semantic_tag'] for lab in label_batch],
        #    'cuda:{:d}'.format(num_gpus-1))
        #for i in range(len(label_batch)):
        #  label_batch[i]['semantic_tag'] = semantic_tags[i]
        #  label_batch[i]['prototype_semantic_tag'] = torch.index_select(
        #      semantic_tags[i],
        #      0,
        #      label_batch[i]['prototype_batch_index'])

        # Add memory bank to label batch.
        for k in memory_banks.keys():
            for i in range(len(label_batch)):
                assert (label_batch[i].get(k, None) is None)
                label_batch[i][k] = [m.to(gpu_ids[i]) for m in memory_banks[k]]

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss', 'feat_aff_loss'
        ]:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Write to tensorboard summary.
        writer = (summary_writer if curr_iter %
                  config.train.tensorboard_step == 0 else None)
        if writer is not None:
            summary_vis = []
            summary_val = {}
            # Gather labels to cpu.
            cpu_label_batch = scatter_gather.gather(label_batch, -1)
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['semantic_label'], color_map))
            summary_vis.append(
                vis_utils.convert_label_to_color(
                    cpu_label_batch['instance_label'], color_map))

            # Gather outputs to cpu.
            vis_names = ['embedding']
            cpu_embeddings = scatter_gather.gather(
                [{k: emb.get(k, None)
                  for k in vis_names} for emb in embeddings], -1)
            for vis_name in vis_names:
                if cpu_embeddings.get(vis_name, None) is not None:
                    summary_vis.append(
                        vis_utils.embedding_to_rgb(cpu_embeddings[vis_name],
                                                   'pca'))

            val_names = [
                'sem_ann_loss', 'sem_occ_loss', 'img_sim_loss',
                'feat_aff_loss', 'accuracy'
            ]
            for val_name in val_names:
                if outputs.get(val_name, None) is not None:
                    summary_val[val_name] = outputs[val_name].mean().to('cpu')

            vis_utils.write_image_to_tensorboard(summary_writer, summary_vis,
                                                 summary_vis[-1].shape[-2:],
                                                 curr_iter)
            vis_utils.write_scalars_to_tensorboard(summary_writer, summary_val,
                                                   curr_iter)

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr, curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr, curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Update memory banks.
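        # Prototypes from the current step are appended; once a bank exceeds
        # config.train.memory_bank_size, the oldest entry is dropped (FIFO).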
        with torch.no_grad():
            for k in label_batch[0].keys():
                if 'prototype' in k and 'memory' not in k:
                    memory = label_batch[0][k].clone().detach()
                    memory_key = 'memory_' + k
                    if memory_key not in memory_banks.keys():
                        memory_banks[memory_key] = []
                    memory_banks[memory_key].append(memory)
                    if len(memory_banks[memory_key]
                           ) > config.train.memory_bank_size:
                        memory_banks[memory_key] = memory_banks[memory_key][1:]

            # Update batch labels.
            for k in ['memory_prototype_batch_index']:
                memory_labels = memory_banks.get(k, None)
                if memory_labels is not None:
                    for i, memory_label in enumerate(memory_labels):
                        memory_labels[i] += config.train.batch_size * num_gpus

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()
            }
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)