Example #1
0
File: snn_eval.py — Project: zlapp/suncet
    def init_pipe(training):
        """Build data transforms and a loader/sampler for one data split.

        When ``training`` is False the unlabeled fraction is forced to 0,
        so evaluation runs on fully labeled data. Returns the tuple
        ``(transform, init_transform, data_loader, data_sampler)``.
        """
        # eval split uses no unlabeled data
        frac = unlabeled_frac if training else 0.

        # -- build the (center-cropped, basic-augmentation) transforms
        transform, init_transform = make_transforms(
            dataset_name=dataset_name,
            subset_path=subset_path,
            unlabeled_frac=frac,
            training=training,
            split_seed=split_seed,
            basic_augmentations=True,
            force_center_crop=True,
            normalize=normalize)

        # -- single-process (world_size=1, rank=0) loader over that split
        data_loader, data_sampler = init_data(
            dataset_name=dataset_name,
            transform=transform,
            init_transform=init_transform,
            u_batch_size=None,
            s_batch_size=16,
            stratify=False,
            classes_per_batch=None,
            world_size=1,
            rank=0,
            root_path=root_path,
            image_folder=image_folder,
            training=training,
            copy_data=False,
            drop_last=False)

        return transform, init_transform, data_loader, data_sampler
Example #2
0
def main(args):
    """Run PAWS self-supervised pre-training.

    Reads the experiment configuration from ``args`` (a nested dict with
    'meta', 'criterion', 'data', 'optimization' and 'logging' sections),
    initializes the distributed backend, encoder, PAWS loss, data pipeline
    and optimizer/scheduler, then trains with mixed precision; rank 0
    writes 'latest', 'best' and periodic epoch checkpoints.
    """

    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #
    # -- META
    model_name = args['meta']['model_name']
    output_dim = args['meta']['output_dim']
    load_model = args['meta']['load_checkpoint']
    r_file = args['meta']['read_checkpoint']
    copy_data = args['meta']['copy_data']
    use_fp16 = args['meta']['use_fp16']
    use_pred_head = args['meta']['use_pred_head']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- CRITERION
    reg = args['criterion']['me_max']
    supervised_views = args['criterion']['supervised_views']
    classes_per_batch = args['criterion']['classes_per_batch']
    s_batch_size = args['criterion']['supervised_imgs_per_class']
    u_batch_size = args['criterion']['unsupervised_batch_size']
    temperature = args['criterion']['temperature']
    sharpen = args['criterion']['sharpen']

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    color_jitter = args['data']['color_jitter_strength']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    multicrop = args['data']['multicrop']
    label_smoothing = args['data']['label_smoothing']
    data_seed = None
    # dataset-dependent crop geometry (small images for cifar10)
    if 'cifar10' in dataset_name:
        data_seed = args['data']['data_seed']
        crop_scale = (0.75, 1.0) if multicrop > 0 else (0.5, 1.0)
        mc_scale = (0.3, 0.75)
        mc_size = 18
    else:
        crop_scale = (0.14, 1.0) if multicrop > 0 else (0.08, 1.0)
        mc_scale = (0.05, 0.14)
        mc_size = 96

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    mom = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    # ----------------------------------------------------------------------- #

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')

    # -- log/checkpointing paths
    log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
    save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
    latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
    best_path = os.path.join(folder, f'{tag}' + '-best.pth.tar')
    load_path = None
    if load_model:
        load_path = os.path.join(folder,
                                 r_file) if r_file is not None else latest_path

    # -- make csv_logger
    csv_logger = CSVLogger(log_file, ('%d', 'epoch'), ('%d', 'itr'),
                           ('%.5f', 'paws-xent-loss'),
                           ('%.5f', 'paws-me_max-reg'), ('%d', 'time (ms)'))

    # -- init model
    encoder = init_model(device=device,
                         model_name=model_name,
                         use_pred=use_pred_head,
                         output_dim=output_dim)
    if world_size > 1:
        # synchronize batch-norm statistics across workers
        process_group = apex.parallel.create_syncbn_process_group(0)
        encoder = apex.parallel.convert_syncbn_model(
            encoder, process_group=process_group)

    # -- init losses
    paws = init_paws_loss(multicrop=multicrop,
                          tau=temperature,
                          T=sharpen,
                          me_max=reg)
    # -- assume support images are sampled with ClassStratifiedSampler
    labels_matrix = make_labels_matrix(num_classes=classes_per_batch,
                                       s_batch_size=s_batch_size,
                                       world_size=world_size,
                                       device=device,
                                       unique_classes=unique_classes,
                                       smoothing=label_smoothing)

    # -- make data transforms
    transform, init_transform = make_transforms(dataset_name=dataset_name,
                                                subset_path=subset_path,
                                                unlabeled_frac=unlabeled_frac,
                                                training=True,
                                                split_seed=data_seed,
                                                crop_scale=crop_scale,
                                                basic_augmentations=False,
                                                color_jitter=color_jitter,
                                                normalize=normalize)
    multicrop_transform = (multicrop, None)
    if multicrop > 0:
        multicrop_transform = make_multicrop_transform(
            dataset_name=dataset_name,
            num_crops=multicrop,
            size=mc_size,
            crop_scale=mc_scale,
            normalize=normalize,
            color_distortion=color_jitter)

    # -- init data-loaders/samplers
    (unsupervised_loader, unsupervised_sampler, supervised_loader,
     supervised_sampler) = init_data(dataset_name=dataset_name,
                                     transform=transform,
                                     init_transform=init_transform,
                                     supervised_views=supervised_views,
                                     u_batch_size=u_batch_size,
                                     s_batch_size=s_batch_size,
                                     unique_classes=unique_classes,
                                     classes_per_batch=classes_per_batch,
                                     multicrop_transform=multicrop_transform,
                                     world_size=world_size,
                                     rank=rank,
                                     root_path=root_path,
                                     image_folder=image_folder,
                                     training=True,
                                     copy_data=copy_data)
    iter_supervised = None
    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

    # -- init optimizer and scheduler
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_opt(encoder=encoder,
                                             weight_decay=wd,
                                             start_lr=start_lr,
                                             ref_lr=lr,
                                             final_lr=final_lr,
                                             ref_mom=mom,
                                             nesterov=nesterov,
                                             iterations_per_epoch=ipe,
                                             warmup=warmup,
                                             num_epochs=num_epochs)
    if world_size > 1:
        encoder = DistributedDataParallel(encoder, broadcast_buffers=False)

    start_epoch = 0
    # -- load training checkpoint
    if load_model:
        encoder, optimizer, start_epoch = load_checkpoint(r_path=load_path,
                                                          encoder=encoder,
                                                          opt=optimizer,
                                                          scaler=scaler,
                                                          use_fp16=use_fp16)
        # fast-forward the per-iteration LR schedule to the resumed epoch
        for _ in range(start_epoch):
            for _ in range(ipe):
                scheduler.step()

    # -- TRAINING LOOP
    best_loss = None
    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d' % (epoch + 1))

        # -- update distributed-data-loader epoch
        unsupervised_sampler.set_epoch(epoch)
        if supervised_sampler is not None:
            supervised_sampler.set_epoch(epoch)

        loss_meter = AverageMeter()
        ploss_meter = AverageMeter()
        rloss_meter = AverageMeter()
        time_meter = AverageMeter()
        data_meter = AverageMeter()

        for itr, udata in enumerate(unsupervised_loader):
            if use_fp16:
                udata = [u.half() for u in udata]

            def load_imgs():
                """Assemble one training batch on the device.

                Returns ``(imgs, mc_imgs, labels)`` where imgs concatenates
                the two unsupervised views with the labeled support images.
                """
                # -- unsupervised imgs (2 views)
                imgs = [u.to(device) for u in udata[:2]]

                # -- unsupervised multicrop img views
                mc_imgs = None
                if multicrop > 0:
                    mc_imgs = torch.cat([u.to(device) for u in udata[2:-1]],
                                        dim=0)

                # -- labeled support imgs: one shared iterator over the
                #    supervised loader, restarted whenever it runs out.
                #    BUGFIX: this was `global iter_supervised`, which bypassed
                #    the `iter_supervised = None` local above (relying on a
                #    swallowed NameError) and leaked iterator state into the
                #    module globals; `nonlocal` binds the enclosing variable.
                nonlocal iter_supervised
                try:
                    sdata = next(iter_supervised)
                except (StopIteration, TypeError):
                    # StopIteration: iterator exhausted;
                    # TypeError: first use, iter_supervised is still None
                    iter_supervised = iter(supervised_loader)
                    logger.info(
                        f'len.supervised_loader: {len(iter_supervised)}')
                    sdata = next(iter_supervised)
                # post-process the supervised batch (moved out of the old
                # `finally:` so an unexpected error cannot leave `sdata`
                # unbound here)
                if use_fp16:
                    sdata = [s.half() for s in sdata]
                simgs = [s.to(device) for s in sdata[:-1]]
                # one labels matrix per supervised view
                labels = torch.cat(
                    [labels_matrix for _ in range(supervised_views)])

                # -- concatenate unlabeled images and labeled support images
                imgs = torch.cat(imgs + simgs, dim=0)

                return imgs, mc_imgs, labels

            (imgs, mc_imgs, labels), dtime = gpu_timer(load_imgs)
            data_meter.update(dtime)

            def train_step():
                """One optimization step; returns (loss, ploss, me_max, lr_stats)."""
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    (h, z), z_mc = encoder(imgs,
                                           mc_imgs,
                                           return_before_head=True)

                    # Compute paws loss in full precision
                    with torch.cuda.amp.autocast(enabled=False):

                        # Step 1. convert representations to fp32
                        h, z = h.float(), z.float()
                        if z_mc is not None:
                            z_mc = z_mc.float()

                        # Step 2. determine anchor views/supports and their
                        #         corresponding target views/supports
                        if not use_pred_head:
                            h = z
                        target_supports = h[2 * u_batch_size:].detach()
                        target_views = h[:2 * u_batch_size].detach()
                        # swap the two view halves so each anchor view is
                        # matched against the *other* view as its target
                        target_views = torch.cat([
                            target_views[u_batch_size:],
                            target_views[:u_batch_size]
                        ],
                                                 dim=0)
                        # --
                        anchor_supports = z[2 * u_batch_size:]
                        anchor_views = z[:2 * u_batch_size]
                        if multicrop > 0:
                            anchor_views = torch.cat([anchor_views, z_mc],
                                                     dim=0)

                        # Step 3. compute paws loss with me-max regularization
                        (ploss, me_max) = paws(anchor_views=anchor_views,
                                               anchor_supports=anchor_supports,
                                               anchor_support_labels=labels,
                                               target_views=target_views,
                                               target_supports=target_supports,
                                               target_support_labels=labels)
                        loss = ploss + me_max

                scaler.scale(loss).backward()
                lr_stats = scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                return (float(loss), float(ploss), float(me_max), lr_stats)

            (loss, ploss, rloss, lr_stats), etime = gpu_timer(train_step)
            loss_meter.update(loss)
            ploss_meter.update(ploss)
            rloss_meter.update(rloss)
            time_meter.update(etime)

            # always log a row when the loss goes nan/inf so the failure is
            # captured in the csv before the assert below fires
            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                csv_logger.log(epoch + 1, itr, ploss_meter.avg,
                               rloss_meter.avg, time_meter.avg)
                logger.info('[%d, %5d] loss: %.3f (%.3f %.3f) '
                            '(%d ms; %d ms)' %
                            (epoch + 1, itr, loss_meter.avg, ploss_meter.avg,
                             rloss_meter.avg, time_meter.avg, data_meter.avg))
                if lr_stats is not None:
                    logger.info('[%d, %5d] lr_stats: %.3f (%.2e, %.2e)' %
                                (epoch + 1, itr, lr_stats.avg, lr_stats.min,
                                 lr_stats.max))

            assert not np.isnan(loss), 'loss is nan'

        # -- logging/checkpointing
        logger.info('avg. loss %.3f' % loss_meter.avg)

        if rank == 0:
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'loss': loss_meter.avg,
                's_batch_size': s_batch_size,
                'u_batch_size': u_batch_size,
                'world_size': world_size,
                'lr': lr,
                'temperature': temperature,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, latest_path)
            if best_loss is None or best_loss > loss_meter.avg:
                best_loss = loss_meter.avg
                logger.info('updating "best" checkpoint')
                torch.save(save_dict, best_path)
            if (epoch + 1) % checkpoint_freq == 0 \
                    or (epoch + 1) % 10 == 0 and epoch < checkpoint_freq:
                torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))
Example #3
0
def main(args):
    """Fine-tune or evaluate a classifier on top of a pre-trained encoder.

    args['meta']['training'] selects the mode: in eval mode the batch size
    is reduced, all labeled data is used (unlabeled_frac forced to 0), a
    checkpoint is always loaded, and a single epoch is run. Returns the
    final ``(train_top1, val_top1)`` accuracies.
    """

    # -- META
    model_name = args['meta']['model_name']
    port = args['meta']['master_port']
    load_checkpoint = args['meta']['load_checkpoint']
    training = args['meta']['training']
    copy_data = args['meta']['copy_data']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    num_classes = args['data']['num_classes']
    data_seed = None
    if 'cifar10' in dataset_name:
        data_seed = args['data']['data_seed']
    # wider random-crop scale range for small cifar10 images
    crop_scale = (0.5, 1.0) if 'cifar10' in dataset_name else (0.08, 1.0)

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    ref_lr = args['optimization']['lr']
    use_lars = args['optimization']['use_lars']
    zero_init = args['optimization']['zero_init']
    num_epochs = args['optimization']['epochs']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed(port)
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- optimization/evaluation params
    # eval mode: small batches, all labels, always resume, single epoch
    if training:
        batch_size = 256
    else:
        batch_size = 16
        unlabeled_frac = 0.0
        load_checkpoint = True
        num_epochs = 1

    # -- init loss
    criterion = torch.nn.CrossEntropyLoss()

    # -- make train data transforms and data loaders/samples
    transform, init_transform = make_transforms(dataset_name=dataset_name,
                                                subset_path=subset_path,
                                                unlabeled_frac=unlabeled_frac,
                                                training=training,
                                                crop_scale=crop_scale,
                                                split_seed=data_seed,
                                                basic_augmentations=True,
                                                normalize=normalize)
    (data_loader, dist_sampler) = init_data(dataset_name=dataset_name,
                                            transform=transform,
                                            init_transform=init_transform,
                                            u_batch_size=None,
                                            s_batch_size=batch_size,
                                            classes_per_batch=None,
                                            world_size=world_size,
                                            rank=rank,
                                            root_path=root_path,
                                            image_folder=image_folder,
                                            training=training,
                                            copy_data=copy_data)

    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- make val data transforms and data loaders/samples
    # NOTE(review): unlabeled_frac=-1 presumably selects the held-out
    # validation split inside make_transforms -- confirm against its
    # implementation
    val_transform, val_init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=-1,
        training=True,
        basic_augmentations=True,
        force_center_crop=True,
        normalize=normalize)
    (val_data_loader,
     val_dist_sampler) = init_data(dataset_name=dataset_name,
                                   transform=val_transform,
                                   init_transform=val_init_transform,
                                   u_batch_size=None,
                                   s_batch_size=batch_size,
                                   classes_per_batch=None,
                                   world_size=1,
                                   rank=0,
                                   root_path=root_path,
                                   image_folder=image_folder,
                                   training=True,
                                   copy_data=copy_data)
    logger.info(f'initialized val data-loader (ipe {len(val_data_loader)})')

    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        device_str=args['meta']['device'],
        num_classes=num_classes,
        training=training,
        use_fp16=use_fp16,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        ref_lr=ref_lr,
        weight_decay=wd,
        use_lars=use_lars,
        zero_init=zero_init,
        num_epochs=num_epochs,
        model_name=model_name)

    best_acc = None
    start_epoch = 0
    # -- load checkpoint
    if not training or load_checkpoint:
        encoder, optimizer, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            sched=scheduler,
            scaler=scaler,
            device_str=args['meta']['device'],
            use_fp16=use_fp16)
    if not training:
        logger.info('putting model in eval mode')
        encoder.eval()
        # log the number of trainable non-fc (backbone) parameters
        logger.info(
            sum(p.numel() for n, p in encoder.named_parameters()
                if p.requires_grad and ('fc' not in n)))
        start_epoch = 0

    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # One pass over the labeled loader; returns top-1 accuracy (%).
            # In eval mode (training=False) the backward/step is skipped,
            # so this is just a measurement pass.
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)
            top1_correct, top5_correct, total = 0, 0, 0
            for i, data in enumerate(data_loader):
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = encoder(inputs)
                    loss = criterion(outputs, labels)
                total += inputs.shape[0]
                top5_correct += float(
                    outputs.topk(5,
                                 dim=1).indices.eq(labels.unsqueeze(1)).sum())
                top1_correct += float(
                    outputs.max(dim=1).indices.eq(labels).sum())
                top1_acc = 100. * top1_correct / total
                top5_acc = 100. * top5_correct / total
                if training:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    # gradients cleared after stepping, so the next
                    # iteration starts from zero
                    optimizer.zero_grad()
                if i % log_freq == 0:
                    logger.info('[%d, %5d] %.3f%% %.3f%% (loss: %.3f)' %
                                (epoch + 1, i, top1_acc, top5_acc, loss))
            return 100. * top1_correct / total

        def val_step():
            # Evaluate a frozen deep-copy of the encoder on the validation
            # loader; returns top-1 accuracy (%).
            val_encoder = copy.deepcopy(encoder).eval()
            top1_correct, total = 0, 0
            for i, data in enumerate(val_data_loader):
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = val_encoder(inputs)
                total += inputs.shape[0]
                top1_correct += float(
                    outputs.max(dim=1).indices.eq(labels).sum())
                top1_acc = 100. * top1_correct / total

            # NOTE(review): assumes val_data_loader is non-empty; otherwise
            # 'i' and 'top1_acc' are unbound at this log call
            logger.info('[%d, %5d] %.3f%%' % (epoch + 1, i, top1_acc))
            return 100. * top1_correct / total

        train_top1 = 0.  # placeholder; immediately overwritten below
        train_top1 = train_step()
        with torch.no_grad():
            val_top1 = val_step()

        log_str = 'train:' if training else 'test:'
        logger.info('[%d] (%s: %.3f%%) (val: %.3f%%)' %
                    (epoch + 1, log_str, train_top1, val_top1))

        # -- logging/checkpointing
        # rank 0 saves a checkpoint whenever validation accuracy improves
        if training and (rank == 0) and ((best_acc is None) or
                                         (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'best_top1_acc': best_acc,
                'batch_size': batch_size,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)

    return train_top1, val_top1
Example #4
0
def main(args):
    """Fine-tune a pre-trained encoder with the SNN/SUNCET loss.

    Each epoch first evaluates the current encoder (``val_run``) and then
    runs one fine-tuning pass; rank 0 saves a checkpoint whenever the
    validation top-1 accuracy improves.
    """

    # -- META
    model_name = args['meta']['model_name']
    load_checkpoint = args['meta']['load_checkpoint']
    copy_data = args['meta']['copy_data']
    output_dim = args['meta']['output_dim']
    use_pred_head = args['meta']['use_pred_head']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    label_smoothing = args['data']['label_smoothing']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    data_seed = args['data']['data_seed']

    # -- CRITERION
    classes_per_batch = args['criterion']['classes_per_batch']
    supervised_views = args['criterion']['supervised_views']
    batch_size = args['criterion']['supervised_batch_size']
    temperature = args['criterion']['temperature']

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    use_lars = args['optimization']['use_lars']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    ref_lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    momentum = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune-SNN.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed()
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- init loss
    suncet = init_suncet_loss(num_classes=classes_per_batch,
                              batch_size=batch_size * supervised_views,
                              world_size=world_size,
                              rank=rank,
                              temperature=temperature,
                              device=device)
    labels_matrix = make_labels_matrix(num_classes=classes_per_batch,
                                       s_batch_size=batch_size,
                                       world_size=world_size,
                                       device=device,
                                       unique_classes=unique_classes,
                                       smoothing=label_smoothing)

    # -- make data transforms
    transform, init_transform = make_transforms(dataset_name=dataset_name,
                                                subset_path=subset_path,
                                                unlabeled_frac=unlabeled_frac,
                                                training=True,
                                                split_seed=data_seed,
                                                basic_augmentations=True,
                                                normalize=normalize)
    # class-stratified (stratify=True) labeled-only loader
    (data_loader,
     dist_sampler) = init_data(dataset_name=dataset_name,
                               transform=transform,
                               init_transform=init_transform,
                               supervised_views=supervised_views,
                               u_batch_size=None,
                               stratify=True,
                               s_batch_size=batch_size,
                               classes_per_batch=classes_per_batch,
                               unique_classes=unique_classes,
                               world_size=world_size,
                               rank=rank,
                               root_path=root_path,
                               image_folder=image_folder,
                               training=True,
                               copy_data=copy_data)

    # -- rough estimate of labeled imgs per class used to set the number of
    #    fine-tuning iterations
    # (1300 imgs/class for imagenet, 5000 for cifar10, scaled by the
    #  labeled fraction)
    imgs_per_class = int(
        1300 * (1. - unlabeled_frac)) if 'imagenet' in dataset_name else int(
            5000 * (1. - unlabeled_frac))
    dist_sampler.set_inner_epochs(imgs_per_class // batch_size)

    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(device=device,
                                               training=True,
                                               r_enc_path=r_enc_path,
                                               iterations_per_epoch=ipe,
                                               world_size=world_size,
                                               start_lr=start_lr,
                                               ref_lr=ref_lr,
                                               num_epochs=num_epochs,
                                               output_dim=output_dim,
                                               model_name=model_name,
                                               warmup_epochs=warmup,
                                               use_pred_head=use_pred_head,
                                               use_fp16=use_fp16,
                                               wd=wd,
                                               final_lr=final_lr,
                                               momentum=momentum,
                                               nesterov=nesterov,
                                               use_lars=use_lars)

    best_acc, val_top1 = None, None
    start_epoch = 0
    # -- load checkpoint
    if load_checkpoint:
        encoder, optimizer, scaler, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            sched=scheduler,
            device=device,
            use_fp16=use_fp16,
            ckp=True)

    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # One fine-tuning pass over the stratified labeled loader.
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)

            for i, data in enumerate(data_loader):
                # data[:-1] are the image views; the last element is
                # presumably the integer labels, unused here -- targets
                # come from labels_matrix instead (TODO confirm)
                imgs = torch.cat([s.to(device) for s in data[:-1]], 0)
                labels = torch.cat(
                    [labels_matrix for _ in range(supervised_views)])
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    z = encoder(imgs)
                    loss = suncet(z, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                if i % log_freq == 0:
                    logger.info('[%d, %5d] (loss: %.3f)' %
                                (epoch + 1, i, loss))

        # validation runs on a copy of the encoder BEFORE this epoch's
        # training pass
        with torch.no_grad():
            with nostdout():
                val_top1, _ = val_run(pretrained=copy.deepcopy(encoder),
                                      subset_path=subset_path,
                                      unlabeled_frac=unlabeled_frac,
                                      dataset_name=dataset_name,
                                      root_path=root_path,
                                      image_folder=image_folder,
                                      use_pred=use_pred_head,
                                      normalize=normalize,
                                      split_seed=data_seed)
        logger.info('[%d] (val: %.3f%%)' % (epoch + 1, val_top1))
        train_step()

        # -- logging/checkpointing
        if (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'batch_size': batch_size,
                'best_top1_acc': best_acc,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)

    # NOTE(review): if the loop above never runs (start_epoch >= num_epochs),
    # 'epoch' is unbound and 'best_acc' may be None at this log call
    logger.info('[%d] (best-val: %.3f%%)' % (epoch + 1, best_acc))