Пример #1
0
def train(train_loader, model, criterion, optimizer, epoch, args, logger,
          time_logger):
    """Train ``model`` for one epoch and log metrics reduced across ranks.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to optimize; switched to training mode here.
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer that updates the model parameters.
        epoch: Epoch index, used in the progress prefix and the log rows.
        args: Namespace providing ``gpu`` and ``print_freq``.
        logger: Object whose ``write`` receives
            ``[epoch, loss_avg, top1_avg, top5_avg]`` on rank 0.
        time_logger: Object whose ``write`` receives
            ``[epoch, batch_time_avg, data_time_avg]`` on rank 0.
    """
    # Running meters for timing, loss and accuracy.
    batch_time = util.AverageMeter('Time', ':6.3f')
    data_time = util.AverageMeter('Data', ':6.3f')
    losses = util.AverageMeter('Loss', ':.4f')
    top1 = util.AverageMeter('Acc@1', ':6.2f')
    top5 = util.AverageMeter('Acc@5', ':6.2f')
    # Progress printer over the meters above.
    progress = util.ProgressMeter(len(train_loader),
                                  [batch_time, data_time, losses, top1, top5],
                                  prefix="Epoch: [{}]".format(epoch))
    # Start of the training loop.
    model.train()
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()

        # Gradient averaging must happen BEFORE the optimizer consumes the
        # gradients; averaging after step() has no effect on the update and
        # lets the per-rank weights drift apart.
        average_gradients(model)
        optimizer.step()

        # Top-1 / top-5 accuracy for this batch.
        acc1, acc5, correct = util.accuracy(output, target, topk=(1, 5))
        # Combine the per-GPU values so every rank logs the same numbers.
        reduced_loss = reduce_tensor(loss.data)
        reduced_top1 = reduce_tensor(acc1[0].data)
        reduced_top5 = reduce_tensor(acc5[0].data)
        # Feed the reduced values into the running meters.
        losses.update(reduced_loss.item(), images.size(0))
        top1.update(reduced_top1.item(), images.size(0))
        top5.update(reduced_top5.item(), images.size(0))

        batch_time.update(time.time() - end)
        end = time.time()
        # Print from a single process only (rank 0).
        if dist.get_rank() == 0:
            if i % args.print_freq == 0:
                progress.display(i)
    # Persist the epoch summary from rank 0 only.
    if dist.get_rank() == 0:
        logger.write([epoch, losses.avg, top1.avg, top5.avg])
        time_logger.write([epoch, batch_time.avg, data_time.avg])
Пример #2
0
def validate(val_loader, model, criterion, epoch, args, logger, time_logger):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    Loss and accuracy of every batch are passed through ``reduce_tensor``
    before being accumulated. Progress output and the ``logger`` /
    ``time_logger`` rows are emitted on rank 0 only.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Epoch index written into the log rows.
        args: Namespace providing ``gpu`` and ``print_freq``.
        logger: Receives ``[epoch, loss_avg, top1_avg, top5_avg]``.
        time_logger: Receives ``[epoch, batch_time_avg, data_time_avg]``.

    Returns:
        The running average held by the top-1 meter.
    """
    meter_time = util.AverageMeter('Time', ':6.3f')
    meter_data = util.AverageMeter('Data', ':6.3f')
    meter_loss = util.AverageMeter('Loss', ':.4f')
    meter_top1 = util.AverageMeter('Acc@1', ':6.2f')
    meter_top5 = util.AverageMeter('Acc@5', ':6.2f')
    display = util.ProgressMeter(
        len(val_loader),
        [meter_time, meter_data, meter_loss, meter_top1, meter_top5],
        prefix='Test: ')

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(val_loader):
            images, target = batch
            # Time spent waiting for the loader.
            meter_data.update(time.time() - tick)

            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            acc1, acc5, _ = util.accuracy(output, target, topk=(1, 5))

            # Combine per-process values before recording them.
            loss_all = reduce_tensor(loss.data)
            top1_all = reduce_tensor(acc1[0].data)
            top5_all = reduce_tensor(acc5[0].data)

            n_samples = images.size(0)
            meter_loss.update(loss_all.item(), n_samples)
            meter_top1.update(top1_all.item(), n_samples)
            meter_top5.update(top5_all.item(), n_samples)

            meter_time.update(time.time() - tick)
            tick = time.time()

            # Only rank 0 prints, and only every print_freq batches.
            if dist.get_rank() == 0 and step % args.print_freq == 0:
                display.display(step)

        if dist.get_rank() == 0:
            summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
                top1=meter_top1, top5=meter_top5)
            print(summary)

    if dist.get_rank() == 0:
        logger.write(
            [epoch, meter_loss.avg, meter_top1.avg, meter_top5.avg])
        time_logger.write([epoch, meter_time.avg, meter_data.avg])

    return meter_top1.avg
Пример #3
0
def validate(val_loader, model, loss_fn, epoch, model_history, trn_loss, args):
    """Evaluate a three-head classifier and return the mean top-1 accuracy.

    The logits are split into grapheme (columns ``:168``), vowel
    (``168:179``) and consonant (``179:``) heads. The loss is a
    0.5 / 0.25 / 0.25 weighted sum of the per-head losses, and the summary
    score is the 2:1:1 weighted average of the per-head macro recalls.

    Args:
        val_loader: Loader wrapped by ``util.data_prefetcher``.
        model: Network to evaluate; switched to eval mode here.
        loss_fn: Loss callable applied per head.
        epoch: Epoch index recorded in ``model_history`` and the summary.
        model_history: Dict of lists; the ``val_*`` and ``val_avg_*`` keys
            are appended to.
        trn_loss: Training losses of the epoch, averaged for the summary line.
        args: Namespace providing ``multigpus_distributed``, ``current_gpu``,
            ``log_interval``, ``world_size`` and ``batch_size``.

    Returns:
        The running average held by the top-1 meter.
    """
    batch_time = util.AverageMeter('Time', ':6.3f')
    losses = util.AverageMeter('Loss', ':.4e')
    top1 = util.AverageMeter('Acc@1', ':6.2f')
    top5 = util.AverageMeter('Acc@5', ':6.2f')
    progress = util.ProgressMeter(len(val_loader),
                                  [batch_time, losses, top1, top5],
                                  prefix='Test: ')
    val_loss = []
    val_true = []
    val_pred = []

    # switch to evaluate mode
    model.eval()
    end = time.time()

    prefetcher = util.data_prefetcher(val_loader)
    input, target = prefetcher.next()
    batch_idx = 0
    # The prefetcher signals exhaustion by returning None as the input.
    while input is not None:
        batch_idx += 1

        # compute output
        with torch.no_grad():
            logits = model(input)
            grapheme = logits[:, :168]
            vowel = logits[:, 168:179]
            cons = logits[:, 179:]

            # Each head is scored against its own target column. The
            # consonant term must use the consonant logits, not the vowel
            # logits.
            loss = (0.5 * loss_fn(grapheme, target[:, 0])
                    + 0.25 * loss_fn(vowel, target[:, 1])
                    + 0.25 * loss_fn(cons, target[:, 2]))
            val_loss.append(loss.item())

            grapheme = grapheme.cpu().argmax(dim=1).data.numpy()
            vowel = vowel.cpu().argmax(dim=1).data.numpy()
            cons = cons.cpu().argmax(dim=1).data.numpy()

            val_true.append(target.cpu().numpy())
            val_pred.append(np.stack([grapheme, vowel, cons], axis=1))

        # measure accuracy and record loss
        # NOTE(review): accuracy is computed on the full concatenated logits
        # against the multi-column target -- confirm util.accuracy supports it.
        prec1, prec5 = util.accuracy(logits, target, topk=(1, 5))

        if args.multigpus_distributed:
            reduced_loss = dis_util.reduce_tensor(loss.data, args)
            prec1 = dis_util.reduce_tensor(prec1, args)
            prec5 = dis_util.reduce_tensor(prec5, args)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO:  Change timings to mirror train().
        if args.current_gpu == 0 and batch_idx % args.log_interval == 0:
            print('Test: [{0}/{1}]  '
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                  'Speed {2:.3f} ({3:.3f})  '
                  'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})  '
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx,
                      len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
            model_history['val_epoch'].append(epoch)
            model_history['val_batch_idx'].append(batch_idx)
            model_history['val_batch_time'].append(batch_time.val)
            model_history['val_losses'].append(losses.val)
            model_history['val_top1'].append(top1.val)
            model_history['val_top5'].append(top5.val)
        input, target = prefetcher.next()

    val_true_concat = np.concatenate(val_true)
    val_pred_concat = np.concatenate(val_pred)
    val_loss_mean = np.mean(val_loss)
    trn_loss_mean = np.mean(trn_loss)

    # Per-head macro recall; the grapheme head counts double in the final score.
    score_g = recall_score(val_true_concat[:, 0],
                           val_pred_concat[:, 0],
                           average='macro')
    score_v = recall_score(val_true_concat[:, 1],
                           val_pred_concat[:, 1],
                           average='macro')
    score_c = recall_score(val_true_concat[:, 2],
                           val_pred_concat[:, 2],
                           average='macro')
    final_score = np.average([score_g, score_v, score_c], weights=[2, 1, 1])

    if args.current_gpu == 0:
        # Printing vital information
        s = f'[Epoch {epoch}] ' \
        f'trn_loss: {trn_loss_mean:.4f}, vld_loss: {val_loss_mean:.4f}, score: {final_score:.4f}, ' \
        f'score_each: [{score_g:.4f}, {score_v:.4f}, {score_c:.4f}]'
        print(s)

    print('  Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1,
                                                                 top5=top5))
    model_history['val_avg_epoch'].append(epoch)
    model_history['val_avg_batch_time'].append(batch_time.avg)
    model_history['val_avg_losses'].append(losses.avg)
    model_history['val_avg_top1'].append(top1.avg)
    model_history['val_avg_top5'].append(top5.avg)
    return top1.avg
Пример #4
0
def train(current_gpu, args):
    """Train a three-head classifier with CutMix and validate after each epoch.

    Builds the model, loss and Adam optimizer, sets up the distributed (and
    optionally apex) wrappers through ``dis_util``, then runs
    ``args.num_epochs`` epochs. Roughly half of the batches are trained on
    the unmodified input; the rest have a random box replaced by the same
    box from a shuffled copy of the batch (CutMix), with the loss mixed by
    the resulting area ratio. After every epoch ``validate`` is called and
    the history and a checkpoint are saved from one process.

    Args:
        current_gpu: Index of the GPU handled by this process; process 0
            prints the per-iteration metrics.
        args: Namespace with the run configuration (model name, learning
            rate, distributed flags, output directories, ...).
    """
    best_acc1 = -1
    model_history = {}
    model_history = util.init_modelhistory(model_history)
    train_start = time.time()

    ## choose model from pytorch model_zoo
    model = util.torch_model(args.model_name, pretrained=True)
    loss_fn = nn.CrossEntropyLoss().cuda()

    ## distributed_setting
    model, args = dis_util.dist_setting(current_gpu, model, loss_fn, args)

    ## CuDNN library will benchmark several algorithms and pick that which it found to be fastest
    # Benchmarking is switched off when a (truthy) seed is set.
    cudnn.benchmark = False if args.seed else True

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.apex:
        model, optimizer = dis_util.apex_init(model, optimizer, args)


#     args.collate_fn = partial(dis_util.fast_collate, memory_format=args.memory_format)

    args = _get_images(args, data_type='train')
    train_loader, train_sampler = _get_train_data_loader(args, **args.kwargs)
    test_loader = _get_test_data_loader(args, **args.kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(train_loader.sampler), len(train_loader.dataset),
        100. * len(train_loader.sampler) / len(train_loader.dataset)))

    logger.info("Processes {}/{} ({:.0f}%) of test data".format(
        len(test_loader.sampler), len(test_loader.dataset),
        100. * len(test_loader.sampler) / len(test_loader.dataset)))

    for epoch in range(1, args.num_epochs + 1):
        ## Fresh meters for every epoch
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(train_loader), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        trn_loss = []
        model.train()
        end = time.time()
        running_loss = 0.0
        ## Set epoch count for DistributedSampler
        if args.multigpus_distributed:
            train_sampler.set_epoch(epoch)

        prefetcher = util.data_prefetcher(train_loader)
        input, target = prefetcher.next()
        batch_idx = 0
        # The loop ends when the prefetcher returns None as the input.
        while input is not None:

            batch_idx += 1

            if args.prof >= 0 and batch_idx == args.prof:
                print("Profiling begun at iteration {}".format(batch_idx))
                torch.cuda.cudart().cudaProfilerStart()

            # NOTE(review): no matching range_pop / cudaProfilerStop is
            # visible in this function -- confirm the range is closed elsewhere.
            if args.prof >= 0:
                torch.cuda.nvtx.range_push(
                    "Body of iteration {}".format(batch_idx))

            util.adjust_learning_rate(optimizer, epoch, batch_idx,
                                      len(train_loader), args)

            ##### DATA Processing #####
            # One target column per head: grapheme, vowel, consonant.
            targets_gra = target[:, 0]
            targets_vow = target[:, 1]
            targets_con = target[:, 2]

            # With 50% probability, use the original data unchanged
            if np.random.rand() < 0.5:
                logits = model(input)
                grapheme = logits[:, :168]
                vowel = logits[:, 168:179]
                cons = logits[:, 179:]

                loss1 = loss_fn(grapheme, targets_gra)
                loss2 = loss_fn(vowel, targets_vow)
                loss3 = loss_fn(cons, targets_con)

            else:
                # CutMix branch: paste a random box from a shuffled copy of
                # the batch and mix the losses of both label sets.

                lam = np.random.beta(1.0, 1.0)
                rand_index = torch.randperm(input.size()[0])
                shuffled_targets_gra = targets_gra[rand_index]
                shuffled_targets_vow = targets_vow[rand_index]
                shuffled_targets_con = targets_con[rand_index]

                bbx1, bby1, bbx2, bby2 = _rand_bbox(input.size(), lam)
                # In-place overwrite of the box region of the batch.
                input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :,
                                                          bbx1:bbx2, bby1:bby2]
                # Adjust lambda so that it exactly matches the pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                           (input.size()[-1] * input.size()[-2]))

                logits = model(input)
                grapheme = logits[:, :168]
                vowel = logits[:, 168:179]
                cons = logits[:, 179:]

                loss1 = loss_fn(grapheme, targets_gra) * lam + loss_fn(
                    grapheme, shuffled_targets_gra) * (1. - lam)
                loss2 = loss_fn(vowel, targets_vow) * lam + loss_fn(
                    vowel, shuffled_targets_vow) * (1. - lam)
                loss3 = loss_fn(cons, targets_con) * lam + loss_fn(
                    cons, shuffled_targets_con) * (1. - lam)

            # Weighted sum of the head losses (grapheme counts double).
            loss = 0.5 * loss1 + 0.25 * loss2 + 0.25 * loss3
            trn_loss.append(loss.item())
            running_loss += loss.item()

            #########################################################

            # compute gradient and do SGD step
            optimizer.zero_grad()

            if args.apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            # Printing vital information
            # NOTE(review): batch_idx is already 1-based here, so the
            # "+ 1" shifts both the trigger and the printed batch number.
            if (batch_idx + 1) % (args.log_interval) == 0:
                s = f'[Epoch {epoch} Batch {batch_idx+1}/{len(train_loader)}] ' \
                f'loss: {running_loss / args.log_interval:.4f}'
                print(s)
                running_loss = 0

            # NOTE(review): "True or" makes this branch run on every
            # iteration, while batch_time below is still divided by
            # log_interval -- confirm which behavior is intended.
            if True or batch_idx % args.log_interval == 0:
                # Every log_interval iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                prec1, prec5 = util.accuracy(logits, target, topk=(1, 5))

                # Average loss and accuracy across processes for logging
                if args.multigpus_distributed:
                    reduced_loss = dis_util.reduce_tensor(loss.data, args)
                    prec1 = dis_util.reduce_tensor(prec1, args)
                    prec5 = dis_util.reduce_tensor(prec5, args)
                else:
                    reduced_loss = loss.data

                # to_python_float incurs a host<->device sync
                losses.update(to_python_float(reduced_loss), input.size(0))
                top1.update(to_python_float(prec1), input.size(0))
                top5.update(to_python_float(prec5), input.size(0))

                ## Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - end) / args.log_interval)
                end = time.time()

                if current_gpu == 0:
                    print(
                        'Epoch: [{0}][{1}/{2}]  '
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                        'Speed {3:.3f} ({4:.3f})  '
                        'Loss {loss.val:.10f} ({loss.avg:.4f})  '
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})  '
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch,
                            batch_idx,
                            len(train_loader),
                            args.world_size * args.batch_size / batch_time.val,
                            args.world_size * args.batch_size / batch_time.avg,
                            batch_time=batch_time,
                            loss=losses,
                            top1=top1,
                            top5=top5))
                    model_history['epoch'].append(epoch)
                    model_history['batch_idx'].append(batch_idx)
                    model_history['batch_time'].append(batch_time.val)
                    model_history['losses'].append(losses.val)
                    model_history['top1'].append(top1.val)
                    model_history['top5'].append(top5.val)

            input, target = prefetcher.next()

        acc1 = validate(test_loader, model, loss_fn, epoch, model_history,
                        trn_loss, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Save from the single process, or from the processes whose rank is a
        # multiple of num_gpus when running distributed.
        if not args.multigpus_distributed or (args.multigpus_distributed and
                                              args.rank % args.num_gpus == 0):
            util.save_history(
                os.path.join(args.output_data_dir, 'model_history.p'),
                model_history)

            util.save_model(
                {
                    'epoch': epoch + 1,
                    'model_name': args.model_name,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    #                 'class_to_idx' : train_loader.dataset.class_to_idx,
                },
                is_best,
                args.model_dir)
Пример #5
0
        pass        
    
    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))
        
        vae.train()
        start = time.time()
        
        for i, (images, _) in enumerate(dl):
            images = images.cuda()
            
            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
            else:
                loss, recons = distr_vae(
                    images,
Пример #6
0
def validate(val_loader, model, criterion, epoch, model_history, args):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    Supports both the plain forward pass and the model-parallel path, where
    ``dis_util.test_step`` returns per-microbatch outputs and losses.

    Args:
        val_loader: Iterable yielding ``(input, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Epoch index recorded in ``model_history``.
        model_history: Dict of lists; the ``val_*`` and ``val_avg_*`` keys
            are appended to.
        args: Namespace providing ``device``, ``model_parallel``, ``rank``,
            ``world_size``, ``batch_size`` and ``assert_losses``.

    Returns:
        The running average held by the top-1 meter.
    """
    batch_time = util.AverageMeter('Time', ':6.3f')
    losses = util.AverageMeter('Loss', ':.4e')
    top1 = util.AverageMeter('Acc@1', ':6.2f')
    top5 = util.AverageMeter('Acc@5', ':6.2f')
    progress = util.ProgressMeter(len(val_loader),
                                  [batch_time, losses, top1, top5],
                                  prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    end = time.time()

    # Batches are numbered from 1 for the progress output.
    for batch_idx, (input, target) in enumerate(val_loader, 1):
        input = input.to(args.device)
        target = target.to(args.device)

        # compute output
        with torch.no_grad():
            if args.model_parallel:
                output, loss = dis_util.test_step(model, criterion, input,
                                                  target)
                # Average the loss across microbatches.
                loss = loss.reduce_mean()
            else:
                output = model(input)
                loss = criterion(output, target)

        # measure accuracy and record loss
        if args.model_parallel:
            # Join the per-microbatch outputs into one batch tensor.
            output = torch.cat(output.outputs)

        prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

        losses.update(util.to_python_float(loss), input.size(0))
        top1.update(util.to_python_float(prec1), input.size(0))
        top5.update(util.to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO:  Change timings to mirror train().
        if args.rank == 0:
            print('Test: [{0}/{1}]  '
                  'Test_Time={batch_time.val:.3f}:({batch_time.avg:.3f}), '
                  'Test_Speed={2:.3f}:({3:.3f}), '
                  'Test_Loss={loss.val:.4f}:({loss.avg:.4f}), '
                  'Test_Prec@1={top1.val:.3f}:({top1.avg:.3f}), '
                  'Test_Prec@5={top5.val:.3f}:({top5.avg:.3f})'.format(
                      batch_idx,
                      len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
            model_history['val_epoch'].append(epoch)
            model_history['val_batch_idx'].append(batch_idx)
            model_history['val_batch_time'].append(batch_time.val)
            model_history['val_losses'].append(losses.val)
            model_history['val_top1'].append(top1.val)
            model_history['val_top5'].append(top5.val)

    print('Prec@1={top1.avg:.3f}, Prec@5={top5.avg:.3f}'.format(top1=top1,
                                                                 top5=top5))
    model_history['val_avg_epoch'].append(epoch)
    model_history['val_avg_batch_time'].append(batch_time.avg)
    model_history['val_avg_losses'].append(losses.avg)
    model_history['val_avg_top1'].append(top1.avg)
    model_history['val_avg_top5'].append(top5.avg)

    if args.assert_losses:
        # Use the same helper module as the rest of this function
        # (``dis_util``); the previous ``dist_util`` spelling is not
        # referenced anywhere else.
        dis_util.smp_lossgather(losses.avg, args)
    return top1.avg
Пример #7
0
def train(local_rank, args):
    """Train and validate a classifier, with optional apex / model / data parallelism.

    Builds the model, loss and Adam optimizer, wraps them through the
    matching ``dis_util`` initializer, then runs ``args.num_epochs`` epochs.
    ``validate`` is called after every epoch, followed by saving the history
    and a checkpoint.

    Args:
        local_rank: Local process index; when not ``None`` it is stored in
            ``args.local_rank``.
        args: Namespace with the run configuration (model name, learning
            rate, parallelism flags, device, output directories, ...).
    """
    best_acc1 = -1
    model_history = {}
    model_history = util.init_modelhistory(model_history)

    if local_rank is not None:
        args.local_rank = local_rank

    # distributed_setting
    if args.multigpus_distributed:
        args = dis_util.dist_setting(args)

    # choose model from pytorch model_zoo
    model = util.torch_model(
        args.model_name,
        num_classes=args.num_classes,
        pretrained=True,
        local_rank=args.local_rank,
        model_parallel=args.model_parallel)  # 1000 resnext101_32x8d
    criterion = nn.CrossEntropyLoss().cuda()

    model, args = dis_util.dist_model(model, args)

    # CuDNN library will benchmark several algorithms and pick that which it found to be fastest
    # Benchmarking is switched off when a (truthy) seed is set.
    cudnn.benchmark = False if args.seed else True

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.apex:
        model, optimizer, args = dis_util.apex_init(model, optimizer, args)
    elif args.model_parallel:
        model, optimizer, args = dis_util.smp_init(model, optimizer, args)
    elif args.data_parallel:
        model, optimizer, args = dis_util.sdp_init(model, optimizer, args)

    train_loader, train_sampler = _get_train_data_loader(args, **args.kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(train_loader.sampler), len(train_loader.dataset),
        100. * len(train_loader.sampler) / len(train_loader.dataset)))

    test_loader = _get_test_data_loader(args, **args.kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of test data".format(
        len(test_loader.sampler), len(test_loader.dataset),
        100. * len(test_loader.sampler) / len(test_loader.dataset)))

    print(" local_rank : {}, local_batch_size : {}".format(
        local_rank, args.batch_size))

    for epoch in range(1, args.num_epochs + 1):
        # Fresh meters for every epoch.
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(train_loader), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        model.train()
        end = time.time()

        # Set epoch count for DistributedSampler
        if args.multigpus_distributed and not args.model_parallel:
            train_sampler.set_epoch(epoch)

        # Batches are numbered from 1 for logging.
        for batch_idx, (input, target) in enumerate(train_loader, 1):
            input = input.to(args.device)
            target = target.to(args.device)

            # Clear stale gradients BEFORE the forward pass. In the
            # model-parallel path the backward pass is presumably performed
            # inside dis_util.train_step (no loss.backward() is issued for it
            # below), so zeroing after that call would erase the gradients
            # right before optimizer.step().
            optimizer.zero_grad()

            if args.model_parallel:
                print("** smp_train_step **")
                output, loss = dis_util.train_step(model, criterion, input,
                                                   target, args.scaler, args)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()

                print("reduce_mean : {}".format(loss))
            else:
                output = model(input)
                loss = criterion(output, target)

            # compute gradient and do SGD step
            if args.apex:
                dis_util.apex_loss(loss, optimizer)
            elif not args.model_parallel:
                loss.backward()

            optimizer.step()

            if args.rank == 0:
                # Metrics are gathered on every iteration on rank 0; they
                # incur host<->device syncs.

                if args.model_parallel:
                    # Join the per-microbatch outputs into one batch tensor.
                    output = torch.cat(output.outputs)

                # Measure accuracy
                prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), input.size(0))
                top1.update(util.to_python_float(prec1), input.size(0))
                top5.update(util.to_python_float(prec5), input.size(0))

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                # NOTE(review): the elapsed time is divided by log_interval
                # although this branch runs every iteration -- confirm intent.
                batch_time.update((time.time() - end) / args.log_interval)
                end = time.time()

                print('Epoch: [{0}][{1}/{2}] '
                      'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                      'Train_Speed={3:.3f} ({4:.3f}), '
                      'Train_Loss={loss.val:.10f}:({loss.avg:.4f}), '
                      'Train_Prec@1={top1.val:.3f}:({top1.avg:.3f}), '
                      'Train_Prec@5={top5.val:.3f}:({top5.avg:.3f})'.format(
                          epoch,
                          batch_idx,
                          len(train_loader),
                          args.world_size * args.batch_size / batch_time.val,
                          args.world_size * args.batch_size / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

        acc1 = validate(test_loader, model, criterion, epoch, model_history,
                        args)

        # Only rank 0 tracks the best accuracy; other ranks keep is_best False.
        is_best = False

        if args.rank == 0:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        if not args.multigpus_distributed or (args.multigpus_distributed
                                              and not args.model_parallel
                                              and args.rank == 0):
            # Record the last-iteration values of the epoch.
            model_history['epoch'].append(epoch)
            model_history['batch_idx'].append(batch_idx)
            model_history['batch_time'].append(batch_time.val)
            model_history['losses'].append(losses.val)
            model_history['top1'].append(top1.val)
            model_history['top5'].append(top5.val)

            util.save_history(
                os.path.join(args.output_data_dir, 'model_history.p'),
                model_history)
            util.save_model(
                {
                    'epoch': epoch + 1,
                    'model_name': args.model_name,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'class_to_idx': train_loader.dataset.class_to_idx,
                }, is_best, args)
        elif args.model_parallel:
            # Model-parallel checkpoints are written by every process through
            # the dedicated helper; the history is saved from rank 0 only.
            if args.rank == 0:
                util.save_history(
                    os.path.join(args.output_data_dir, 'model_history.p'),
                    model_history)
            dis_util.smp_savemodel(model, optimizer, is_best, args)
Пример #8
0
def main():
    """Train a ``DiscreteVAE`` on an image folder and checkpoint it.

    Command-line arguments come from ``get_parser()``. Two execution modes
    are supported, selected by ``args.model_parallel``:

    * model parallel: the model and optimizer are wrapped with
      ``smp.DistributedModel`` / ``smp.DistributedOptimizer`` and the
      forward/backward pass runs inside ``smp.step`` functions;
    * otherwise: model, optimizer, data and scheduler are handed to
      ``deepspeed_utils.maybe_distribute``.

    Every 10th iteration, rank 0 builds sample reconstructions, a checkpoint
    is written to ``{args.model_dir}/vae.pt``, the temperature is annealed
    and the learning-rate scheduler is stepped. After the last epoch, rank 0
    of a non-model-parallel run also writes ``vae-final.pt``.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if the dataset is empty, or if ``args.assert_losses``
            is set in model-parallel mode and the final loss checks fail.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    # Defaults; rank stays -1 unless smp or the deepspeed root worker
    # overrides it below.
    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    # SageMaker: when SM_MODEL_DIR is set, take the model directory and the
    # training channel from the environment instead of the CLI arguments.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    # With data parallelism on top of model parallelism, split the dataset
    # into equal shares and select the share matching this dp rank.
    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy

        # Plain copies of the codebook and decoder, taken before the model is
        # wrapped by smp; they are used below for the hard reconstructions.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        # A partial checkpoint takes precedence over a full one.
        checkpoint_path = args.partial_checkpoint or args.full_checkpoint
        if checkpoint_path:
            args.checkpoint = smp.load(checkpoint_path,
                                       partial=bool(args.partial_checkpoint))
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Save hyper-parameters and weights to ``path`` (rank 0 only)."""
        if args.rank != 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    # The smp step functions are only called in model-parallel mode, so they
    # are only defined there. Any failure while decorating them now surfaces
    # immediately instead of being swallowed.
    if args.model_parallel:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            """Run forward and backward; return ``(loss, recons)``."""
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            """Return codebook indices for the first ``k`` images."""
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            """Decode codebook indices with the plain codebook/decoder copies."""
            from functools import partial

            def restore_original_forward(root):
                # Rebind every submodule's ``forward`` to the original method
                # recorded by smp's patch manager.
                for module in root.modules():
                    method = smp_state.patch_manager.get_original_method(
                        "forward", type(module))
                    module.forward = partial(method, module)

            restore_original_forward(dummy_codebook)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            restore_original_forward(dummy_decoder)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy the current weights into the plain
                            # decoder/codebook copies by stripping the
                            # sub-module prefix from the state-dict keys.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Use separate names for the preview grids so that
                    # ``images`` keeps referring to the training batch; it is
                    # read again below for the loss meter weight.
                    samples = [
                        t.detach().cpu()
                        for t in (images[:k], recons[:k], hard_recons)
                    ]
                    codes = codes.detach().cpu()
                    # These grids are not consumed anywhere in this function
                    # at the moment.
                    grid_images, grid_recons, grid_hard_recons = (
                        make_grid(t.float(),
                                  nrow=int(sqrt(k)),
                                  normalize=True,
                                  range=(-1, 1)) for t in samples)

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # ``smp.dp_rank`` is a function and must be called;
                    # comparing the function object itself to 0 is always
                    # False, which would skip the checkpoint entirely.
                    if smp.dp_rank() == 0:
                        save_full = bool(args.save_full_model)
                        if save_full:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                        smp.save(
                            {
                                "model_state_dict": model_dict,
                                "optimizer_state_dict": opt_dict
                            },
                            filename,
                            partial=not save_full,
                        )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    # ``logs`` is assembled but not sent anywhere in this
                    # function.
                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

            global_step += 1

            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                # NOTE(review): ``start`` is only reset at the top of each
                # epoch, so this records time since epoch start divided by
                # ``args.log_interval``; ``end`` is never read. Confirm the
                # intended metric before changing it.
                batch_time.update((time.time() - start) / args.log_interval)
                end = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")