Example #1
def validate(valid_loader, model, epoch, cur_step, writer, logger, config):
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)

    step_num = len(valid_loader)

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.size(0)

            logits, _ = model(X)
            loss = criterion(logits, y)

            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))

            if config.distributed:
                reduced_loss = utils.reduce_tensor(loss.data,
                                                   config.world_size)
                prec1 = utils.reduce_tensor(prec1, config.world_size)
                prec5 = utils.reduce_tensor(prec5, config.world_size)
            else:
                reduced_loss = loss.data

            losses.update(reduced_loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            torch.cuda.synchronize()

            if (step % config.print_freq == 0
                    or step == step_num - 1) and config.local_rank == 0:
                logger.info(
                    "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        config.epochs,
                        step,
                        step_num,
                        losses=losses,
                        top1=top1,
                        top5=top5))

    if config.local_rank == 0:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)

        logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch + 1, config.epochs, top1.avg))

    return top1.avg, top5.avg
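Note: every example in this collection calls a `reduce_tensor` helper that is never shown. The usual pattern behind such a helper is an all-reduce followed by division by the world size; the sketch below assumes that convention and is not the actual `utils.reduce_tensor` from the repository above. Some of the later examples omit the `world_size` argument and presumably read it from `dist.get_world_size()` instead.

import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    """Average a scalar tensor across all distributed workers (illustrative sketch)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value from every rank
    rt /= world_size                           # turn the sum into a mean
    return rt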
Example #2
def validate(testloader, model, test_size, local_rank):
    if local_rank <= 0:
        logging.info('Start evaluation...')
    model.eval()
    ave_loss = AverageMeter()
    with torch.no_grad():
        iterator = tqdm(testloader, ascii=True) if local_rank <= 0 else testloader
        for batch in iterator:
            def handle_batch():
                a, fg, bg, _, _ = batch      # [B, 3, 3 or 1, H, W]
                out = model(a, fg, bg)
                L_alpha = out[0].mean()
                L_comp = out[1].mean()
                L_grad = out[2].mean()
                #L_temp = out[3].mean()
                #loss['L_total'] = 0.5 * loss['L_alpha'] + 0.5 * loss['L_comp'] + loss['L_grad'] + 0.5 * loss['L_temp']
                #loss['L_total'] = loss['L_alpha'] + loss['L_comp'] + loss['L_grad'] + loss['L_temp']
                loss = L_alpha + L_comp + L_grad
                return loss.detach()

            loss = handle_batch()
            reduced_loss = reduce_tensor(loss)

            ave_loss.update(reduced_loss.item())
    if local_rank <= 0:
        logging.info('Validation loss: {:.6f}'.format(ave_loss.average()))
    return ave_loss.average()
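Note: Examples #2, #6, and #8 use an `AverageMeter` that exposes `update()`, `value()`, and `average()`, unlike the `.avg`-attribute meter in Example #1. A minimal sketch consistent with those call sites (the real class in the repository may track more state):

class AverageMeter:
    """Running value and average of a scalar metric (illustrative sketch)."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    def value(self):
        return self.val

    def average(self):
        return self.sum / max(self.count, 1)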
Example #3
def train(train_dataprovider, val_dataprovider, optimizer, scheduler, model, archloader, criterion, args, val_iters, seed, writer=None):
    objs, top1 = AvgrageMeter(), AvgrageMeter()

    for p in model.parameters():
        p.grad = torch.zeros_like(p)

    for step in range(args.total_iters):
        model.train()
        t0 = time.time()
        image, target = train_dataprovider.next()
        datatime = time.time() - t0
        n = image.size(0)
        optimizer.zero_grad()
        # torch.autograd.Variable is a no-op wrapper in modern PyTorch
        image = image.cuda(args.gpu)
        target = target.cuda(args.gpu)

        # Fair Sampling
        fair_arc_list = archloader.generate_niu_fair_batch()

        for arc in fair_arc_list:
            logits = model(image, archloader.convert_list_arc_str(arc))
            loss = criterion(logits, target)
            loss_reduce = reduce_tensor(loss, 0, args.world_size)
            loss.backward()

        nn.utils.clip_grad_value_(model.parameters(), args.grad_clip)
        optimizer.step()
        scheduler.step()

        prec1, _ = accuracy(logits, target, topk=(1, 5))
        objs.update(loss_reduce.data.item(), n)
        top1.update(prec1.data.item(), n)

        if step % args.report_freq == 0 and args.local_rank == 0:
            now = time.strftime('%Y-%m-%d %H:%M:%S',
                                time.localtime(time.time()))
            print('{} |=> train: {} / {}, lr={}, loss={:.2f}, acc={:.2f}, datatime={:.2f}, seed={}'
                  .format(now, step, args.total_iters, scheduler.get_lr()[0], objs.avg, top1.avg, float(datatime), seed))

        if args.local_rank == 0 and step % 5 == 0 and writer is not None:
            writer.add_scalar("Train/loss", objs.avg, step)
            writer.add_scalar("Train/acc1", top1.avg, step)

        if args.local_rank == 0 and step % args.report_freq == 0:
            # model

            top1_val, objs_val = infer(train_dataprovider, val_dataprovider, model.module, criterion,
                                       fair_arc_list, val_iters, archloader)

            if writer is not None:
                writer.add_scalar("Val/loss", objs_val, step)
                writer.add_scalar("Val/acc1", top1_val, step)

            save_checkpoint(
                {'state_dict': model.state_dict(), }, step, args.exp)
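Note: the `accuracy(logits, target, topk=(1, 5))` helper used here (and in Examples #1, #7, #9, and #10) is another utility the snippets do not show. A common implementation returns one value per requested k; the sketch below returns fractions in [0, 1], which matches the `:.1%` formatting in Example #1, while other repositories multiply by 100 instead, so treat the exact scaling as an assumption.

def accuracy(output, target, topk=(1,)):
    """Top-k classification accuracy from logits (illustrative sketch)."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest-scoring classes: [B, maxk] -> [maxk, B]
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k / batch_size)  # fraction in [0, 1]
    return res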
Example #4
    def _validate_loop(self, batch, iteration, val_objs, val_args):
        start_time = time.time()

        # load validation image and ground truth labels
        image, target = val_objs.batch_fn(batch=batch,
                                          num_classes=val_args.num_classes,
                                          mode='val')

        image = image.cuda()
        target = target.cuda()
        target_squeezed = squeeze_one_hot(target)

        # forward pass and compute loss
        with torch.no_grad():
            output = val_objs.model(image)
            loss = self.criterion(output, target)

        val_objs.meters = compute_accuracy_dist(
            output=output,
            target_squeezed=target_squeezed,
            meters=val_objs.meters,
            world_size=self.context.world_size)

        reduced_loss = reduce_tensor(tensor=loss.data,
                                     world_size=self.context.world_size)

        val_objs.meters.get('losses').update(reduced_loss.item(),
                                             image.size(0))
        val_objs.meters.get('batch_time').update(time.time() - start_time)

        if (self.context.gpu_no != 0
                or iteration % self.args.util.print_freq != 0
                or iteration == 0):
            return

        # print intermediate results
        PrintCollection.print_val_batch_info(args=self.args,
                                             iteration=iteration,
                                             meters=val_objs.meters,
                                             val_len=val_args.len)
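Note: `squeeze_one_hot` is assumed here to collapse one-hot (or soft ReLabel) targets back into class indices so that accuracy can be computed against hard labels. A plausible one-liner, offered only as a guess at its behavior:

def squeeze_one_hot(target):
    # Map [B, num_classes] one-hot / soft targets to [B] class indices (assumed behavior).
    return target.argmax(dim=1)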
Example #5
def main(config):
    save_path = config['save_path']
    epochs = config['epochs']
    os.environ['TORCH_HOME'] = config['torch_home']
    distributed = config['use_DDP']
    start_ep = 0
    start_cnt = 0

    # initialize model
    print("Initializing model...")
    if distributed:
        initialize_distributed(config)
    rank = config['rank']

    # map string name to class constructor
    model = get_model(config)
    model.apply(init_weights)
    if config['resume_ckpt'] is not None:
        # load weights from checkpoint
        state_dict = load_weights(config['resume_ckpt'])
        model.load_state_dict(state_dict)

    print("Moving model to GPU")
    model.cuda(torch.cuda.current_device())
    print("Setting up losses")

    if config['use_vgg']:
        criterionVGG = Vgg19PerceptualLoss(config['reduced_w'])
        criterionVGG.cuda()
        validationLoss = criterionVGG
    if config['use_gan']:
        use_sigmoid = config['no_lsgan']
        disc_input_channels = 3
        discriminator = MultiscaleDiscriminator(disc_input_channels,
                                                config['ndf'],
                                                config['n_layers_D'],
                                                'instance', use_sigmoid,
                                                config['num_D'], False, False)
        discriminator.apply(init_weights)
        if config['resume_ckpt_D'] is not None:
            # load weights from checkpoint
            print("Resuming discriminator from %s" % (config['resume_ckpt_D']))
            state_dict = load_weights(config['resume_ckpt_D'])
            discriminator.load_state_dict(state_dict)

        discriminator.cuda(torch.cuda.current_device())
        criterionGAN = GANLoss(use_lsgan=not config['no_lsgan'])
        criterionGAN.cuda()
        criterionFeat = nn.L1Loss().cuda()
    if config['use_l2']:
        criterionMSE = nn.MSELoss()
        criterionMSE.cuda()
        validationLoss = criterionMSE

    # initialize dataloader
    print("Setting up dataloaders...")
    train_dataloader, val_dataloader, train_sampler = setup_dataloaders(config)
    print("Done!")
    # run the training loop
    print("Initializing optimizers...")
    optimizer_G = optim.Adam(model.parameters(),
                             lr=config['learning_rate'],
                             weight_decay=config['weight_decay'])
    if config['resume_ckpt_opt_G'] is not None:
        optimizer_G_state_dict = torch.load(
            config['resume_ckpt_opt_G'],
            map_location=lambda storage, loc: storage)
        optimizer_G.load_state_dict(optimizer_G_state_dict)
    if config['use_gan']:
        optimizer_D = optim.Adam(discriminator.parameters(),
                                 lr=config['learning_rate'])
        if config['resume_ckpt_opt_D'] is not None:
            optimizer_D_state_dict = torch.load(
                config['resume_ckpt_opt_D'],
                map_location=lambda storage, loc: storage)
            optimizer_D.load_state_dict(optimizer_D_state_dict)

    print("Done!")

    if distributed:
        print("Moving model to DDP...")
        model = DDP(model)
        if config['use_gan']:
            discriminator = DDP(discriminator, delay_allreduce=True)
        print("Done!")

    tb_logger = None
    if rank == 0:
        tb_logdir = os.path.join(save_path, 'tbdir')
        if not os.path.exists(tb_logdir):
            os.makedirs(tb_logdir)
        tb_logger = SummaryWriter(tb_logdir)
        # run training
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        log_name = os.path.join(save_path, 'loss_log.txt')
        opt_name = os.path.join(save_path, 'opt.yaml')
        print(config)
        save_options(opt_name, config)
        log_handle = open(log_name, 'a')

    print("Starting training")
    cnt = start_cnt
    assert (config['use_warped'] or config['use_temporal'])

    for ep in range(start_ep, epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(ep)

        for curr_batch in train_dataloader:
            optimizer_G.zero_grad()
            input_a = curr_batch['input_a'].cuda()
            target = curr_batch['target'].cuda()
            if config['use_warped'] and config['use_temporal']:
                input_a = torch.cat((input_a, input_a), 0)
                input_b = torch.cat((curr_batch['input_b'].cuda(),
                                     curr_batch['input_temporal'].cuda()), 0)
                target = torch.cat((target, target), 0)
            elif config['use_temporal']:
                input_b = curr_batch['input_temporal'].cuda()
            elif config['use_warped']:
                input_b = curr_batch['input_b'].cuda()

            output_dict = model(input_a, input_b)
            output_recon = output_dict['reconstruction']

            loss_vgg = loss_G_GAN = loss_G_feat = loss_l2 = 0
            if config['use_vgg']:
                loss_vgg = criterionVGG(output_recon,
                                        target) * config['vgg_lambda']
            if config['use_gan']:
                predicted_landmarks = output_dict['input_a_gauss_maps']
                # output_dict['reconstruction'] can be considered normalized
                loss_G_GAN, loss_D_real, loss_D_fake = apply_GAN_criterion(
                    output_recon, target, predicted_landmarks.detach(),
                    discriminator, criterionGAN)
                loss_D = (loss_D_fake + loss_D_real) * 0.5
            if config['use_l2']:
                loss_l2 = criterionMSE(output_recon,
                                       target) * config['l2_lambda']

            loss_G = loss_G_GAN + loss_G_feat + loss_vgg + loss_l2
            loss_G.backward()
            # grad_norm clipping
            if not config['no_grad_clip']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer_G.step()
            if config['use_gan']:
                optimizer_D.zero_grad()
                loss_D.backward()
                # grad_norm clipping
                if not config['no_grad_clip']:
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                                   1.0)
                optimizer_D.step()

            if distributed:
                if config['use_vgg']:
                    loss_vgg = reduce_tensor(loss_vgg, config['world_size'])

            if rank == 0:
                if cnt % 10 == 0:
                    run_visualization(output_dict, output_recon, target,
                                      input_a, input_b, save_path, tb_logger,
                                      cnt)

                print_dict = {"learning_rate": get_learning_rate(optimizer_G)}
                if config['use_vgg']:
                    tb_logger.add_scalar('vgg.loss', loss_vgg, cnt)
                    print_dict['Loss_VGG'] = loss_vgg.data
                if config['use_gan']:
                    tb_logger.add_scalar('gan.loss', loss_G_GAN, cnt)
                    tb_logger.add_scalar('d_real.loss', loss_D_real, cnt)
                    tb_logger.add_scalar('d_fake.loss', loss_D_fake, cnt)
                    print_dict['Loss_G_GAN'] = loss_G_GAN
                    print_dict['Loss_real'] = loss_D_real.data
                    print_dict['Loss_fake'] = loss_D_fake.data
                if config['use_l2']:
                    tb_logger.add_scalar('l2.loss', loss_l2, cnt)
                    print_dict['Loss_L2'] = loss_l2.data

                log_iter(ep,
                         cnt % len(train_dataloader),
                         len(train_dataloader),
                         print_dict,
                         log_handle=log_handle)

            if loss_G != loss_G:  # NaN check: NaN is the only value not equal to itself
                print("NaN!!")
                exit(-2)

            cnt = cnt + 1
            # end of train iter loop

            if cnt % config['val_freq'] == 0 and config['val_freq'] > 0:
                val_loss = run_val(
                    model, validationLoss, val_dataloader,
                    os.path.join(save_path, 'val_%d_renders' % (ep)))

                if distributed:
                    val_loss = reduce_tensor(val_loss, config['world_size'])
                if rank == 0:
                    tb_logger.add_scalar('validation.loss', val_loss, cnt)
                    log_iter(ep,
                             cnt % len(train_dataloader),
                             len(train_dataloader), {"Loss_VGG": val_loss},
                             header="Validation loss: ",
                             log_handle=log_handle)

        if rank == 0:
            if (ep % config['save_freq'] == 0):
                fname = 'checkpoint_%d.ckpt' % (ep)
                fname = os.path.join(save_path, fname)
                print("Saving model...")
                save_weights(model, fname, distributed)
                optimizer_g_fname = os.path.join(
                    save_path, 'latest_optimizer_g_state.ckpt')
                torch.save(optimizer_G.state_dict(), optimizer_g_fname)
                if config['use_gan']:
                    fname = 'checkpoint_D_%d.ckpt' % (ep)
                    fname = os.path.join(save_path, fname)
                    save_weights(discriminator, fname, distributed)
                    optimizer_d_fname = os.path.join(
                        save_path, 'latest_optimizer_d_state.ckpt')
                    torch.save(optimizer_D.state_dict(), optimizer_d_fname)
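Note: two small helpers in this script, `get_learning_rate` and `save_weights`, are not shown. Assuming the common conventions (read the first parameter group's lr; unwrap the DDP `.module` before saving), minimal sketches might look like the following; the repository's own versions may differ.

import torch

def get_learning_rate(optimizer):
    """Learning rate of the first parameter group (illustrative sketch)."""
    return optimizer.param_groups[0]['lr']

def save_weights(model, fname, distributed):
    """Save a state_dict, unwrapping the DDP wrapper when present (illustrative sketch)."""
    state_dict = model.module.state_dict() if distributed else model.state_dict()
    torch.save(state_dict, fname)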
Example #6
def train(epoch, trainloader, steps_per_val, base_lr,
          total_epochs, optimizer, model, 
          adjust_learning_rate, print_freq, 
          image_freq, image_outdir, local_rank, sub_losses):    
    # Training
    model.train()
    
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*steps_per_val
    for i_iter, dp in enumerate(trainloader):
        def handle_batch():
            a, fg, bg = dp      # [B, 3, 3 or 1, H, W]
            #print (a.shape)
            out = model(a, fg, bg)
            L_alpha = out[0].mean()
            L_comp = out[1].mean()
            L_grad = out[2].mean()
            vis_alpha = L_alpha.detach().item()
            vis_comp = L_comp.detach().item()
            vis_grad = L_grad.detach().item()
            #L_temp = out[3].mean()
            #loss['L_total'] = 0.5 * loss['L_alpha'] + 0.5 * loss['L_comp'] + loss['L_grad'] + 0.5 * loss['L_temp']
            #loss['L_total'] = loss['L_alpha'] + loss['L_comp'] + loss['L_grad'] + loss['L_temp']
            loss = L_alpha + L_comp + L_grad

            model.zero_grad()
            loss.backward()
            optimizer.step()
            return loss.detach(), vis_alpha, vis_comp, vis_grad, out[3:]

        loss, vis_alpha, vis_comp, vis_grad, vis_out = handle_batch()

        reduced_loss = reduce_tensor(loss)
        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        torch_barrier()

        adjust_learning_rate(optimizer,
                            base_lr,
                            total_epochs * steps_per_val,
                            i_iter+cur_iters)

        if i_iter % print_freq == 0 and local_rank <= 0:
            msg = 'Iter:[{}/{}], Time: {:.2f}, '.format(
                i_iter + cur_iters, total_epochs * steps_per_val, batch_time.average())
            msg += 'lr: {}, Avg. Loss: {:.6f} | Current: Loss: {:.6f}, '.format(
                [x['lr'] for x in optimizer.param_groups],
                ave_loss.average(), ave_loss.value())
            msg += '{}: {:.4f} {}: {:.4f} {}: {:.4f}'.format(
                sub_losses[0], vis_alpha, 
                sub_losses[1], vis_comp,
                sub_losses[2], vis_grad)
            logging.info(msg)
        
        if i_iter % image_freq == 0 and local_rank <= 0:
            write_image(image_outdir, vis_out, i_iter+cur_iters)
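Note: `torch_barrier()` (also called in Example #8) presumably wraps `torch.distributed.barrier` with a guard so single-process runs do not fail; a minimal sketch under that assumption:

import torch.distributed as dist

def torch_barrier():
    """Synchronize all workers; a no-op when distributed is not initialized (sketch)."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()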
Example #7
def train(train_loader, model, optimizer, epoch, writer, logger, config):
    device = torch.device("cuda")
    if config.label_smooth > 0:
        criterion = CrossEntropyLabelSmooth(config.n_classes,
                                            config.label_smooth).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    step_num = len(train_loader)
    cur_step = epoch * step_num
    cur_lr = optimizer.param_groups[0]['lr']
    if config.local_rank == 0:
        logger.info("Train Epoch {} LR {}".format(epoch, cur_lr))
        writer.add_scalar('train/lr', cur_lr, cur_step)

    model.train()

    for step, (X, y) in enumerate(train_loader):
        X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
        N = X.size(0)

        X, target_a, target_b, lam = data_utils.mixup_data(X,
                                                           y,
                                                           config.mixup_alpha,
                                                           use_cuda=True)

        optimizer.zero_grad()
        logits, logits_aux = model(X)
        # loss = criterion(logits, y)
        loss = data_utils.mixup_criterion(criterion, logits, target_a,
                                          target_b, lam)
        if config.aux_weight > 0:
            # loss_aux = criterion(logits_aux, y)
            loss_aux = data_utils.mixup_criterion(criterion, logits_aux,
                                                  target_a, target_b, lam)
            loss = loss + config.aux_weight * loss_aux

        if config.use_amp:
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.module.parameters(), config.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
        if config.distributed:
            reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
            prec1 = utils.reduce_tensor(prec1, config.world_size)
            prec5 = utils.reduce_tensor(prec5, config.world_size)
        else:
            reduced_loss = loss.data

        losses.update(reduced_loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        torch.cuda.synchronize()
        if config.local_rank == 0 and (step % config.print_freq == 0
                                       or step == step_num - 1):
            logger.info(
                "Train: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1,
                    config.epochs,
                    step,
                    step_num,
                    losses=losses,
                    top1=top1,
                    top5=top5))

        if config.local_rank == 0:
            writer.add_scalar('train/loss', reduced_loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            cur_step += 1

    if config.local_rank == 0:
        logger.info("Train: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch + 1, config.epochs, top1.avg))
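Note: Example #7 trains with mixup through `data_utils.mixup_data` and `data_utils.mixup_criterion`. The sketch below follows the standard mixup recipe (interpolate inputs with a Beta-distributed coefficient and mix the two losses with the same coefficient); it is an assumption about what those helpers do, not code from the repository above.

import numpy as np
import torch

def mixup_data(x, y, alpha, use_cuda=True):
    """Return mixed inputs plus both target sets and the mixing coefficient (sketch)."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)  # use_cuda kept only for signature parity
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, logits, target_a, target_b, lam):
    """Convex combination of the losses against the two target sets (sketch)."""
    return lam * criterion(logits, target_a) + (1 - lam) * criterion(logits, target_b)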
Example #8
def validate(testloader, model, test_size, local_rank, dataset_samples, tmp_folder='/dev/shm/val_tmp'):
    if local_rank <= 0:
        logging.info('Start evaluation...')
    model.eval()
    ave_loss = AverageMeter()
    c = len(dataset_samples[0]) // 2
    
    # We calculate L_dt as a mere indicator of temporal consistency.
    # Since we have sample_length=3 during validation, which means
    # there's only one prediction (the middle frame). Thus, here we 
    # first save the prediction to tmp_folder then compute L_dt in
    # one pass.
    with torch.no_grad():
        iterator = tqdm(testloader, ascii=True) if local_rank <= 0 else testloader
        for batch in iterator:
            fg, bg, a, idx = batch      # [B, 3, 3 or 1, H, W]
            def handle_batch():
                out = model(a, fg, bg)
                L_alpha = out[0].mean()
                L_comp = out[1].mean()
                L_grad = out[2].mean()
                loss = L_alpha + L_comp + L_grad
                return loss.detach(), out[6].detach(), out[7].detach()
            
            loss, tris, alphas = handle_batch()
            reduced_loss = reduce_tensor(loss)

            for i in range(tris.shape[0]):
                fn = dataset_samples[idx[i].item()][c]
                outpath = os.path.join(tmp_folder, fn)
                os.makedirs(os.path.dirname(outpath), exist_ok=True)
                pred = np.uint8((alphas[i, c, 0] * 255).cpu().numpy())
                tri = tris[i, c, 0] * 255
                tri = np.uint8(((tri > 0) * (tri < 255)).cpu().numpy() * 255)
                gt = np.uint8(a[i, c, 0].numpy())
                out = np.stack([pred, tri, gt], axis=-1)
                cv.imwrite(outpath, out)

            ave_loss.update(reduced_loss.item())
    loss = ave_loss.average()

    if local_rank <= 0:
        logging.info('Validation loss: {:.6f}'.format(ave_loss.average()))

        def _read_output(fn):
            fn = os.path.join(tmp_folder, fn)
            preds = cv.imread(fn)
            a, m, g = np.split(preds, 3, axis=-1)
            a = np.float32(np.squeeze(a)) / 255.0
            m = np.squeeze(m) != 0
            g = np.float32(np.squeeze(g)) / 255.0
            return a, g, m 
        res = 0.
        for sample in tqdm(dataset_samples, ascii=True):
            a, g, m = _read_output(sample[c])
            ha, hg, _ = _read_output(sample[c+1])
            dadt = a - ha
            dgtdt = g - hg
            if np.sum(m) == 0:
                continue
            res += np.mean(np.abs(dadt[m] - dgtdt[m]))
        res /= float(len(dataset_samples))
        logging.info('Average L_dt: {:.6f}'.format(res))
        loss += res
        shutil.rmtree(tmp_folder)

    torch_barrier()
    return loss
Example #9
def search(train_loader, valid_loader, model, optimizer, w_optim, alpha_optim,
           layer_idx, epoch, writer, logger, config):
    # interactive retrain and kl

    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()
    losses_interactive = utils.AverageMeter()
    losses_cls = utils.AverageMeter()
    losses_reg = utils.AverageMeter()

    step_num = len(train_loader)
    step_num = int(step_num * config.sample_ratio)

    cur_step = epoch * step_num
    cur_lr_search = w_optim.param_groups[0]['lr']
    cur_lr_main = optimizer.param_groups[0]['lr']
    if config.local_rank == 0:
        logger.info("Train Epoch {} Search LR {}".format(epoch, cur_lr_search))
        logger.info("Train Epoch {} Main LR {}".format(epoch, cur_lr_main))
        writer.add_scalar('retrain/lr', cur_lr_search, cur_step)

    model.train()

    for step, ((trn_X, trn_y),
               (val_X, val_y)) in enumerate(zip(train_loader, valid_loader)):
        if step > step_num:
            break

        trn_X, trn_y = trn_X.to(device,
                                non_blocking=True), trn_y.to(device,
                                                             non_blocking=True)
        val_X, val_y = val_X.to(device,
                                non_blocking=True), val_y.to(device,
                                                             non_blocking=True)
        N = trn_X.size(0)

        #use valid data
        alpha_optim.zero_grad()
        optimizer.zero_grad()

        logits_search, emsemble_logits_search = model(val_X,
                                                      layer_idx,
                                                      super_flag=True)
        logits_main, emsemble_logits_main = model(val_X,
                                                  layer_idx,
                                                  super_flag=False)

        loss_cls = (criterion(logits_search, val_y) +
                    criterion(logits_main, val_y)) / config.loss_alpha
        loss_interactive = Loss_interactive(
            emsemble_logits_search, emsemble_logits_main, config.loss_T,
            config.interactive_type) * config.loss_alpha

        loss_regular = 0 * loss_cls
        if config.regular:
            reg_decay = max(
                config.regular_coeff *
                (1 - float(epoch - config.pretrain_epochs) /
                 ((config.search_iter - config.pretrain_epochs) *
                  config.search_iter_epochs * config.regular_ratio)), 0)
            # normal cell
            op_opt = ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect']
            op_groups = []
            for idx in range(layer_idx, 3):
                for op_dx in op_opt:
                    op_groups.append((idx - layer_idx, op_dx))
            loss_regular = loss_regular + model.module.add_alpha_regularization(
                op_groups, weight_decay=reg_decay, method='L1', reduce=False)

            # reduction cell
            # op_opt = []
            op_opt = ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect']
            op_groups = []
            for i in range(layer_idx, 3):
                for op_dx in op_opt:
                    op_groups.append((i - layer_idx, op_dx))
            loss_regular = loss_regular + model.module.add_alpha_regularization(
                op_groups, weight_decay=reg_decay, method='L1', normal=False)

        loss = loss_cls + loss_interactive + loss_regular
        loss.backward()
        nn.utils.clip_grad_norm_(model.module.parameters(), config.w_grad_clip)
        optimizer.step()
        alpha_optim.step()

        prec1, prec5 = utils.accuracy(logits_main, val_y, topk=(1, 5))
        if config.distributed:
            reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
            reduced_loss_interactive = utils.reduce_tensor(
                loss_interactive.data, config.world_size)
            reduced_loss_cls = utils.reduce_tensor(loss_cls.data,
                                                   config.world_size)
            reduced_loss_reg = utils.reduce_tensor(loss_regular.data,
                                                   config.world_size)
            prec1 = utils.reduce_tensor(prec1, config.world_size)
            prec5 = utils.reduce_tensor(prec5, config.world_size)

        else:
            reduced_loss = loss.data
            reduced_loss_interactive = loss_interactive.data
            reduced_loss_cls = loss_cls.data
            reduced_loss_reg = loss_regular.data

        losses.update(reduced_loss.item(), N)
        losses_interactive.update(reduced_loss_interactive.item(), N)
        losses_cls.update(reduced_loss_cls.item(), N)
        losses_reg.update(reduced_loss_reg.item(), N)

        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        torch.cuda.synchronize()
        if config.local_rank == 0 and (step % config.print_freq == 0
                                       or step == step_num):
            logger.info(
                "Train_2: Layer {}/{} Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Loss_interactive {losses_interactive.avg:.3f} Losses_cls {losses_cls.avg:.3f} Losses_reg {losses_reg.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    layer_idx + 1,
                    config.layer_num,
                    epoch + 1,
                    config.search_iter * config.search_iter_epochs,
                    step,
                    step_num,
                    losses=losses,
                    losses_interactive=losses_interactive,
                    losses_cls=losses_cls,
                    losses_reg=losses_reg,
                    top1=top1,
                    top5=top5))

        if config.local_rank == 0:
            writer.add_scalar('retrain/loss', reduced_loss.item(), cur_step)
            writer.add_scalar('retrain/top1', prec1.item(), cur_step)
            writer.add_scalar('retrain/top5', prec5.item(), cur_step)
            cur_step += 1

        w_optim.zero_grad()
        logits_search_train, _ = model(trn_X, layer_idx, super_flag=True)
        loss_cls_train = criterion(logits_search_train, trn_y)
        loss_train = loss_cls_train
        loss_train.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.module.parameters(), config.w_grad_clip)
        # only update w
        w_optim.step()

        # alpha_optim.step()
        if config.distributed:
            reduced_loss_cls_train = utils.reduce_tensor(
                loss_cls_train.data, config.world_size)
            reduced_loss_train = utils.reduce_tensor(loss_train.data,
                                                     config.world_size)
        else:
            reduced_loss_cls_train = loss_cls_train.data
            reduced_loss_train = loss_train.data

        if config.local_rank == 0 and (step % config.print_freq == 0
                                       or step == step_num - 1):
            logger.info("Train_1: Loss_cls: {:.3f} Loss: {:.3f}".format(
                reduced_loss_cls_train.item(), reduced_loss_train.item()))

    if config.local_rank == 0:
        logger.info(
            "Train_2: Layer {}/{} Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
                layer_idx + 1, config.layer_num, epoch + 1,
                config.search_iter * config.search_iter_epochs, top1.avg))
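Note: `Loss_interactive` pairs the search network's ensemble logits with the main network's, which suggests a temperature-scaled KL divergence in the style of knowledge distillation. The sketch below is an assumption of that form and ignores the `interactive_type` switch; the repository's implementation may differ.

import torch.nn.functional as F

def Loss_interactive(logits_a, logits_b, T, interactive_type=None):
    """Temperature-scaled KL between two sets of logits (assumed form)."""
    p = F.log_softmax(logits_a / T, dim=1)
    q = F.softmax(logits_b / T, dim=1)
    return F.kl_div(p, q, reduction='batchmean') * (T * T)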
Example #10
def retrain_warmup(valid_loader, model, optimizer, layer_idx, epoch, writer,
                   logger, super_flag, retrain_epochs, config):

    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    step_num = len(valid_loader)
    step_num = int(step_num * config.sample_ratio)

    cur_step = epoch * step_num
    cur_lr = optimizer.param_groups[0]['lr']
    if config.local_rank == 0:
        logger.info("Warmup Epoch {} LR {:.3f}".format(epoch + 1, cur_lr))
        writer.add_scalar('warmup/lr', cur_lr, cur_step)

    model.train()

    for step, (val_X, val_y) in enumerate(valid_loader):
        if step > step_num:
            break

        val_X, val_y = val_X.to(device,
                                non_blocking=True), val_y.to(device,
                                                             non_blocking=True)
        N = val_X.size(0)

        optimizer.zero_grad()
        logits_main, _ = model(val_X, layer_idx, super_flag=super_flag)
        loss = criterion(logits_main, val_y)
        loss.backward()

        nn.utils.clip_grad_norm_(model.module.parameters(), config.w_grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits_main, val_y, topk=(1, 5))
        if config.distributed:
            reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
            prec1 = utils.reduce_tensor(prec1, config.world_size)
            prec5 = utils.reduce_tensor(prec5, config.world_size)

        else:
            reduced_loss = loss.data

        losses.update(reduced_loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        torch.cuda.synchronize()
        if config.local_rank == 0 and (step % config.print_freq == 0
                                       or step == step_num):
            logger.info(
                "Warmup: Layer {}/{} Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f}  "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    layer_idx + 1,
                    config.layer_num,
                    epoch + 1,
                    retrain_epochs,
                    step,
                    step_num,
                    losses=losses,
                    top1=top1,
                    top5=top5))

        if config.local_rank == 0:
            writer.add_scalar('retrain/loss', reduced_loss.item(), cur_step)
            writer.add_scalar('retrain/top1', prec1.item(), cur_step)
            writer.add_scalar('retrain/top5', prec5.item(), cur_step)
            cur_step += 1

    if config.local_rank == 0:
        logger.info(
            "Warmup: Layer {}/{} Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
                layer_idx + 1, config.layer_num, epoch + 1, retrain_epochs,
                top1.avg))
Example #11
    def _train_loop(self, iteration, batch, epoch, batch_start_time,
                    train_objs, train_args):
        if train_args.use_relabel:
            # load ReLabel ground truth
            image, target_original, target_relabel = train_objs.batch_fn(
                batch=batch, num_classes=train_args.num_classes, mode='train')
            target_original = target_original.cuda()
            target = target_relabel
        else:
            # load original imagenet ground truth
            image, target_original = train_objs.batch_fn(
                batch=batch, num_classes=train_args.num_classes, mode='train')
            target_original = target_original.cuda()
            target = target_original

        batch_size = image.size(0)

        current_lr = adjust_learning_rate(
            optimizer=train_objs.optimizer,
            epoch=epoch,
            iteration=iteration,
            lr_decay_type=self.args.optim.lr.decay_type,
            epochs=train_args.epochs,
            train_len=train_args.len,
            warmup_lr=train_args.warmup_lr,
            warmup_epochs=train_args.warmup_epochs)

        image = image.cuda()

        # apply cutmix augmentation
        if self.args.data.cutmix.prob > 0. and self.args.data.cutmix.beta > 0.:
            cutmix_args = mch(
                beta=self.args.data.cutmix.beta,
                prob=self.args.data.cutmix.prob,
                num_classes=self.context.num_classes,
                smoothing=self.args.optim.label_smoothing,
                disable=epoch >=
                (self.args.optim.epochs - self.args.data.cutmix.off_epoch))
            image, target = cutmix_batch(image, target, cutmix_args)

        # forward and compute loss
        output = train_objs.model(image)
        loss = self.criterion(output, target)

        train_objs.optimizer.zero_grad()
        with amp.scale_loss(loss, train_objs.optimizer) as scaled_loss:
            scaled_loss.backward()

        # optimizer steps
        train_objs.optimizer.step()
        train_objs.optimizer.zero_grad()

        if iteration % self.args.util.print_freq != 0 or iteration == 0:
            return

        # print intermediate results
        target_squeezed = squeeze_one_hot(target_original)
        train_objs.meters = compute_accuracy_dist(
            output=output,
            target_squeezed=target_squeezed,
            meters=train_objs.meters,
            world_size=self.context.world_size)

        reduced_loss = reduce_tensor(loss.data, self.context.world_size)
        train_objs.meters.get('losses').update(reduced_loss.item(), batch_size)

        torch.cuda.synchronize()
        train_objs.meters.get('batch_time').update(
            (time.time() - batch_start_time))

        if self.context.gpu_no != 0:
            return

        PrintCollection.print_train_batch_info(args=self.args,
                                               epoch=epoch,
                                               iteration=iteration,
                                               train_len=train_args.len,
                                               meters=train_objs.meters,
                                               current_lr=current_lr)

        sys.stdout.flush()