def main():
    args = EigenmetricRegressionSolver.get_eigenmetric_regression_arguments()
    model = AlbertForEigenmetricRegression.from_scratch(
        num_labels=len(ALBERT_EIGENMETRICS),
        top_comment_pretrained_model_name_or_path=args.top_comment_pretrained_model_name_or_path,
        post_pretrained_model_name_or_path=args.post_pretrained_model_name_or_path,
        classifier_dropout_prob=args.classifier_dropout_prob,
        meta_data_size=len(ALBERT_META_FEATURES),
        subreddit_pretrained_path=args.subreddit_pretrained_path,
        num_subreddit_embeddings=NUM_SUBREDDIT_EMBEDDINGS,
        subreddit_embeddings_size=SUBREDDIT_EMBEDDINGS_SIZE)
    if args.freeze_alberts:
        model = model.freeze_bert()

    save_dict = {
        "model_construct_params_dict": model.param_dict(),
        "state_dict": model.state_dict()
    }

    logx.initialize(logdir=args.output_dir,
                    coolname=True,
                    tensorboard=False,
                    no_timestamp=False,
                    eager_flush=True)

    logx.save_model(save_dict, metric=0, epoch=0, higher_better=False)
    def __train_per_epoch(self, epoch_idx: int, steps_per_eval: int):
        with tqdm(total=len(self.train_dataloader),
                  desc=f"Epoch {epoch_idx}") as pbar:
            for batch_idx, batch in enumerate(self.train_dataloader):
                global_step = epoch_idx * len(self.train_dataloader) + batch_idx
                loss = self.__training_step(batch)
                if self.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                loss.backward()
                logx.metric(
                    'train', {
                        "tr_loss": loss.item(),
                        "learning_rate": self.scheduler.get_last_lr()[0]
                    }, global_step)
                pbar.set_postfix_str(f"tr_loss: {loss.item():.5f}")
                # update weights
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule
                if batch_idx % steps_per_eval == 0:
                    # validate and save checkpoints
                    # downsample a subset of dev dataset
                    eval_dataset = self.dev_dataloader.dataset
                    subset_size = len(eval_dataset) // 500
                    eval_sampled_dataloader = DataLoader(
                        Subset(
                            self.dev_dataloader.dataset,
                            random.sample(range(len(eval_dataset)),
                                          subset_size)),
                        shuffle=True,
                        batch_size=self.batch_size,
                        pin_memory=True)
                    mean_loss, metrics_scores, _, _ = self.validate(
                        eval_sampled_dataloader)
                    logx.metric('val', metrics_scores, global_step)
                    if self.n_gpu > 1:
                        save_dict = {
                            "model_construct_params_dict": self.model.module.param_dict(),
                            "state_dict": self.model.module.state_dict(),
                            "solver_construct_params_dict": self.state_dict(),
                            "optimizer": self.optimizer.state_dict()
                        }
                    else:
                        save_dict = {
                            "model_construct_params_dict": self.model.param_dict(),
                            "state_dict": self.model.state_dict(),
                            "solver_construct_params_dict": self.state_dict(),
                            "optimizer": self.optimizer.state_dict()
                        }

                    logx.save_model(save_dict,
                                    metric=mean_loss,
                                    epoch=global_step,
                                    higher_better=False)
                pbar.update(1)
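All of the examples on this page follow the same runx logging lifecycle: call logx.initialize once, log scalars with logx.metric, free-form text with logx.msg, and checkpoints with logx.save_model. The sketch below condenses that shared pattern; the logdir, metric names, and dummy values are placeholders, not taken from any example above.

from runx.logx import logx

logx.initialize(logdir='./logs/demo', coolname=True, tensorboard=False)
for epoch in range(3):
    train_loss = 1.0 / (epoch + 1)  # placeholder value
    logx.msg('epoch {}: train loss {:.4f}'.format(epoch, train_loss))
    logx.metric('train', {'loss': train_loss}, epoch)
    logx.save_model({'state_dict': {}},  # normally model.state_dict()
                    metric=train_loss,
                    epoch=epoch,
                    higher_better=False,
                    delete_old=True)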
Example #3
def validation(args, model, device, val_loader, optimizer, epoch, criterion):
    model.eval()
    n_val = len(val_loader)
    val_loss = 0
    val_psnr = 0
    for batch_idx, batch_data in enumerate(val_loader):
        batch_ldr0, batch_ldr1, batch_ldr2 = batch_data['input0'].to(device), batch_data['input1'].to(device), \
                                             batch_data['input2'].to(device)
        label = batch_data['label'].to(device)

        with torch.no_grad():
            pred = model(batch_ldr0, batch_ldr1, batch_ldr2)
            pred = range_compressor_tensor(pred)
            pred = torch.clamp(pred, 0., 1.)

        loss = criterion(pred, label)
        psnr = batch_PSNR(pred, label, 1.0)
        logx.msg('Validation set: PSNR: {:.4f}'.format(psnr))

        iteration = (epoch - 1) * len(val_loader) + batch_idx
        if epoch % 100 == 0:
            logx.add_image('val/input1', batch_ldr0[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('val/input2', batch_ldr1[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('val/input3', batch_ldr2[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('val/pred', pred[0][[2, 1, 0], :, :], iteration)
            logx.add_image('val/gt', label[0][[2, 1, 0], :, :], iteration)

        val_loss += loss.item()
        val_psnr += psnr

    val_loss /= n_val
    val_psnr /= n_val
    logx.msg('Validation set: Average loss: {:.4f}'.format(val_loss))
    logx.msg('Validation set: Average PSNR: {:.4f}\n'.format(val_psnr))

    # capture metrics
    metrics = {'psnr': val_psnr}
    logx.metric('val', metrics, epoch)
    # save_model
    save_dict = {
        'epoch': epoch + 1,
        'arch': 'AHDRNet',
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }

    logx.save_model(save_dict,
                    epoch=epoch,
                    metric=val_loss,
                    higher_better=False)  # lower validation loss is better
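Note on higher_better: the call above ranks checkpoints by validation loss, so higher_better must be False. A hedged alternative (not from the original code) is to rank checkpoints by PSNR instead, where higher really is better:

logx.save_model(save_dict,
                epoch=epoch,
                metric=val_psnr,
                higher_better=True)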
Example #4
def test_epoch(epoch):
    model.eval()
    losses = 0.0
    total, correct = 0.0, 0.0
    with torch.no_grad():
        for step, (x, y) in enumerate(val_loader):
            x, y = x.to(config.device), y.to(config.device)
            out = model(x)
            loss = criterion(out, y)
            losses += loss.cpu().detach().numpy()
            _, pred = torch.max(out.data, 1)
            total += y.size(0)
            correct += (pred == y).squeeze().sum().cpu().numpy()
    save_dict = {
        'state_dict': model.state_dict()
    }
    logx.msg("epoch {} validation loss {} validation acc {}".format(epoch, losses / (step + 1), correct / total))
    logx.metric('val', {'loss': losses / (step + 1), 'acc': correct / total}, epoch)
    logx.save_model(save_dict, losses / (step + 1), epoch, higher_better=False, delete_old=True)
Example #5
def test(args, model, device, test_loader, epoch, optimizer):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)
    logx.msg(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset), accuracy))

    # capture metrics
    metrics = {'loss': test_loss, 'accuracy': accuracy}
    logx.metric('val', metrics, epoch)

    # save model
    save_dict = {
        'epoch': epoch + 1,
        'arch': 'lenet',
        'state_dict': model.state_dict(),
        'accuracy': accuracy,
        'optimizer': optimizer.state_dict()
    }

    logx.save_model(save_dict,
                    metric=accuracy,
                    epoch=epoch,
                    higher_better=True)
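The example above accumulates per-sample losses with reduction='sum' and divides by len(test_loader.dataset) instead of averaging per-batch means, which keeps the result exact even when the last batch is smaller. A small self-contained illustration (all values are made up):

import torch
import torch.nn.functional as F

logits = torch.randn(10, 5)                 # 10 samples, 5 classes
targets = torch.randint(0, 5, (10,))
log_probs = F.log_softmax(logits, dim=1)
total_loss = F.nll_loss(log_probs, targets, reduction='sum')
mean_per_sample = total_loss.item() / len(targets)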
Example #6
def train():
    for time in range(5):
        logx.initialize(get_logdir("../runs"),
                        tensorboard=True,
                        coolname=False)

        model.load_state_dict(
            torch.load("..\\runs\exp10\last_checkpoint_ep0.pth")
            ['state_dict'])  # warmup

        dataset_train = TrainDataset(
            '../' + cfg.root_folder +
            '/five_fold/train_kfold_{}.csv'.format(time),
            '../' + cfg.root_folder + '/train/', train_transform)
        train_loader = DataLoader(dataset_train,
                                  batch_size=cfg.bs,
                                  shuffle=True)
        test_data = TrainDataset(
            '../' + cfg.root_folder +
            '/five_fold/test_kfold_{}.csv'.format(time),
            '../' + cfg.root_folder + '/train/',
        )
        test_load = DataLoader(test_data, batch_size=cfg.bs, shuffle=False)

        # train
        for epoch in range(cfg.epoch):
            model.train()  # back to training mode (validation below calls model.eval())
            loss_epoch = 0
            total = 0
            correct = 0
            for i, (x, y) in enumerate(train_loader, 1):
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                # running accuracy
                total += x.size(0)
                _, predict = torch.max(y_hat.data, dim=1)
                correct += (predict == y).sum().item()

                # loss
                loss = criterion(y_hat, y)
                loss_epoch += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # progress logging
                if i % 30 == 0:
                    print(
                        'epoch:%d,  enumerate:%d,  loss_avg:%f,  now_acc:%f' %
                        (epoch, i, loss_epoch / i, correct / total))

            # per-epoch metric logging
            train_loss = loss_epoch / i
            train_acc = (correct / total) * 100
            logx.metric('train', {'loss': train_loss, 'acc': train_acc}, epoch)

            # validation
            # accuracy on the dev set
            model.eval()  # switch to eval mode (affects dropout / batch norm)
            correct = 0
            total = 0
            val_loss = 0
            with torch.no_grad():
                for i, (img, label) in enumerate(test_load, 1):
                    img, label = img.to(device), label.to(device)
                    output = model(img)
                    loss = criterion(output, label)
                    val_loss += loss.cpu().item()
                    _, predicted = torch.max(output.data, dim=1)  # (max value, index)
                    total += img.size(0)
                    correct += (predicted == label).sum().item()
            val_acc = (100 * correct / total)
            val_loss /= i
            logx.metric('val', {'loss': val_loss, 'acc': val_acc}, epoch)
            # epoch loss and other metrics
            print(
                'epoch over; train_loss:%f, val_loss:%f, train_acc=%f, val_acc:%f'
                % (train_loss, val_loss, train_acc, val_acc))
            logx.save_model({
                'state_dict': model.state_dict(),
                'epoch': epoch
            },
                            val_acc,
                            higher_better=True,
                            epoch=epoch,
                            delete_old=True)
            scheduler.step()
Example #7
def eval_metrics(iou_acc, args, net, optim, val_loss, epoch, mf_score=None):
    """
    Modified IOU mechanism for on-the-fly IOU calculation (prevents memory
    overflow for large datasets). Only applies to eval/eval.py.
    """
    was_best = False

    iou_per_scale = {}
    iou_per_scale[1.0] = iou_acc
    if args.amp or args.apex:
        iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
        torch.distributed.all_reduce(iou_acc_tensor,
                                     op=torch.distributed.ReduceOp.SUM)
        iou_per_scale[1.0] = iou_acc_tensor.cpu().numpy()

    scales = [1.0]

    # Only rank 0 should save models and calculate metrics
    if args.global_rank != 0:
        return None, 0

    hist = iou_per_scale[args.default_scale]
    iu, acc, acc_cls = calculate_iou(hist)
    iou_per_scale = {args.default_scale: iu}

    # calculate iou for other scales
    for scale in scales:
        if scale != args.default_scale:
            iou_per_scale[scale], _, _ = calculate_iou(iou_per_scale[scale])

    print_evaluate_results(hist,
                           iu,
                           epoch=epoch,
                           iou_per_scale=iou_per_scale,
                           log_multiscale_tb=args.log_msinf_to_tb)

    freq = hist.sum(axis=1) / hist.sum()
    mean_iu = np.nanmean(iu)
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    metrics = {
        'loss': val_loss.avg,
        'mean_iu': mean_iu,
        'acc_cls': acc_cls,
        'acc': acc,
    }
    logx.metric('val', metrics, epoch)
    logx.msg('Mean: {:2.2f}'.format(mean_iu * 100))

    save_dict = {
        'epoch': epoch,
        'arch': args.arch,
        'num_classes': cfg.DATASET_INST.num_classes,
        'state_dict': net.state_dict(),
        'optimizer': optim.lcl_optimizer.state_dict() if args.heat else optim.state_dict(),
        'mean_iu': mean_iu,
        'command': ' '.join(sys.argv[1:])
    }
    logx.save_model(save_dict, metric=mean_iu, epoch=epoch)
    torch.cuda.synchronize()

    if mean_iu > args.best_record['mean_iu']:
        was_best = True

        args.best_record['val_loss'] = val_loss.avg
        if mf_score is not None:
            args.best_record['mask_f1_score'] = mf_score.avg
        args.best_record['acc'] = acc
        args.best_record['acc_cls'] = acc_cls
        args.best_record['fwavacc'] = fwavacc
        args.best_record['mean_iu'] = mean_iu
        args.best_record['epoch'] = epoch

    logx.msg('-' * 107)
    if mf_score is None:
        fmt_str = ('{:5}: [epoch {}], [val loss {:0.5f}], [acc {:0.5f}], '
                   '[acc_cls {:.5f}], [mean_iu {:.5f}], [fwavacc {:0.5f}]')
        current_scores = fmt_str.format('this', epoch, val_loss.avg, acc,
                                        acc_cls, mean_iu, fwavacc)
        logx.msg(current_scores)
        best_scores = fmt_str.format('best', args.best_record['epoch'],
                                     args.best_record['val_loss'],
                                     args.best_record['acc'],
                                     args.best_record['acc_cls'],
                                     args.best_record['mean_iu'],
                                     args.best_record['fwavacc'])
        logx.msg(best_scores)
    else:
        fmt_str = ('{:5}: [epoch {}], [val loss {:0.5f}], [mask f1 {:.5f} ] '
                   '[acc {:0.5f}], '
                   '[acc_cls {:.5f}], [mean_iu {:.5f}], [fwavacc {:0.5f}]')
        current_scores = fmt_str.format('this', epoch, val_loss.avg,
                                        mf_score.avg, acc, acc_cls, mean_iu,
                                        fwavacc)
        logx.msg(current_scores)
        best_scores = fmt_str.format(
            'best', args.best_record['epoch'], args.best_record['val_loss'],
            args.best_record['mask_f1_score'], args.best_record['acc'],
            args.best_record['acc_cls'], args.best_record['mean_iu'],
            args.best_record['fwavacc'])
        logx.msg(best_scores)
    logx.msg('-' * 107)

    return was_best, mean_iu
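calculate_iou is not shown above; below is a plausible minimal sketch of what it computes from the accumulated confusion matrix hist (rows = ground truth, columns = prediction). The actual helper in this codebase may differ in details.

import numpy as np

def calculate_iou(hist):
    # overall pixel accuracy
    acc = np.diag(hist).sum() / hist.sum()
    # mean per-class accuracy
    acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))
    # per-class intersection-over-union
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    return iu, acc, acc_cls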
Example #8
                              batch_size=args.batch_size,
                              shuffle=True)
    valid_dataset = Train_Dataset('./data/new_valid.csv',
                                  './data/train/',
                                  transform=valid_transformer)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size)
    best_accuracy = 0
    for epoch in range(args.epochs):
        print("epoch:" + str(epoch))
        train_acc, train_loss = train(my_model,
                                      train_loader,
                                      optimizer,
                                      scheduler=scheduler)
        metric_train = {'train_acc': train_acc, 'train_loss': train_loss}
        logx.metric('train', metric_train, epoch)
        # torch.save({'state_dict': my_model.state_dict()}, './weights/resnet50_last.pth')
        valid_acc, valid_loss = valid(my_model, valid_loader)
        metric_valid = {'valid_acc': valid_acc, 'valid_loss': valid_loss}
        logx.metric('val', metric_valid, epoch)
        if valid_acc > best_accuracy:
            best_accuracy = valid_acc
            torch.save({'state_dict': my_model.state_dict()},
                       './logs/exp9/highest_valid_acc.pth')
        logx.save_model({'state_dict': my_model.state_dict()},
                        valid_loss,
                        epoch,
                        higher_better=False,
                        delete_old=True)
        print("current_acc:{0}, best_acc:{1}".format(valid_acc, best_accuracy))
Example #9
 def __train_per_epoch(self, epoch_idx, steps_per_eval):
     # with tqdm(total=len(self.train_dataloader), desc=f"Epoch {epoch_idx}") as pbar:
     for batch_idx, batch in enumerate(self.train_dataloader):
         # assume that the whole input matrix fits the GPU memory
         global_step = epoch_idx * len(self.train_dataloader) + batch_idx
         training_set_loss, training_set_outputs, training_set_output_similarity = self.__training_step(
             batch)
         if batch_idx + 1 == len(self.train_dataloader):
             # validate and save checkpoints
             developing_set_outputs, developing_set_metrics_scores, developing_set_output_similarity = \
                 self.validate(self.dev_dataloader)
             # TODO: this part can be optimized to batchwise computing
             if self.record_training_loss_per_epoch:
                 training_set_metrics_scores, _ = \
                     self.get_scores(self.train_decoder,
                                     training_set_outputs,
                                     self.train_dataloader.dataset.anchor_idx)
             else:
                 training_set_metrics_scores = dict()
             training_set_metrics_scores['loss'] = training_set_loss.item()
             if self.scheduler:
                 training_set_metrics_scores[
                     'learning_rate'] = self.scheduler.get_last_lr()[0]
             logx.metric('train', training_set_metrics_scores, global_step)
             logx.metric('val', developing_set_metrics_scores, global_step)
             if self.n_gpu > 1:
                 save_dict = {
                     "model_construct_dict": self.model.module.config,
                     "model_state_dict": self.model.module.state_dict(),
                     "solver_construct_params_dict":
                     self.construct_param_dict,
                     "optimizer": self.optimizer.state_dict(),
                     "train_scores": training_set_metrics_scores,
                     "train_input_embedding":
                     self.train_dataloader.dataset.x,
                     "train_input_similarity":
                     self.train_dataloader.dataset.input_similarity,
                     "train_output_embedding": training_set_outputs,
                     "train_output_similarity":
                     training_set_output_similarity,
                     "dev_scores": developing_set_metrics_scores,
                     "dev_input_embeddings": self.dev_dataloader.dataset.x,
                     "dev_input_similarity":
                     self.dev_dataloader.dataset.input_similarity,
                     "dev_output_embedding": developing_set_outputs,
                     "dev_output_similarity":
                     developing_set_output_similarity,
                 }
             else:
                 save_dict = {
                     "model_construct_dict": self.model.config,
                     "model_state_dict": self.model.state_dict(),
                     "solver_construct_params_dict":
                     self.construct_param_dict,
                     "optimizer": self.optimizer.state_dict(),
                     "train_scores": training_set_metrics_scores,
                     "train_input_embedding":
                     self.train_dataloader.dataset.x,
                     "train_input_similarity":
                     self.train_dataloader.dataset.input_similarity,
                     "train_output_embedding": training_set_outputs,
                     "train_output_similarity":
                     training_set_output_similarity,
                     "dev_scores": developing_set_metrics_scores,
                     "dev_input_embeddings": self.dev_dataloader.dataset.x,
                     "dev_input_similarity":
                     self.dev_dataloader.dataset.input_similarity,
                     "dev_output_embedding": developing_set_outputs,
                     "dev_output_similarity":
                     developing_set_output_similarity,
                 }
             logx.save_model(
                 save_dict,
                 metric=developing_set_metrics_scores['Recall@1'],
                 epoch=global_step,
                 higher_better=True)
Example #10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = get_model_by_name(cfg.model_name)
    if opt.pretrained:
        net.load_state_dict(torch.load(opt.pretrained)['state_dict'])
    net.to(device)
    # define the loss function and optimization strategy
    criterion = JointLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=LR,
                          momentum=0.9,
                          weight_decay=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=EPOCH,
                                                     eta_min=1e-10,
                                                     last_epoch=start_epoch - 1)

    for i in range(start_epoch, start_epoch + EPOCH):
        train(i)
        scheduler.step()
        valid_acc = valid(i)
        logx.save_model({
            'state_dict': net.state_dict(),
            'epoch': i
        },
                        metric=valid_acc,
                        epoch=i,
                        higher_better=True,
                        delete_old=True)
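Passing last_epoch=start_epoch - 1 lets the cosine schedule resume where a checkpoint left off, but PyTorch then expects an 'initial_lr' entry in each optimizer param group. A minimal, self-contained sketch of that resume setup (the model, learning rate, and epoch numbers are placeholders, not from the example above):

import torch
from torch import optim

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
for group in optimizer.param_groups:
    group.setdefault('initial_lr', group['lr'])  # required when last_epoch != -1

start_epoch, EPOCH = 5, 50
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                 T_max=EPOCH,
                                                 eta_min=1e-10,
                                                 last_epoch=start_epoch - 1)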
Example #11
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        logx.msg("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    args.arch = 'resnet18'
    logx.msg("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=args.num_classes)
    '''
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    '''

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logx.msg("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logx.msg("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logx.msg("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, epoch)

        # remember best acc@1 and save checkpoint
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_dict = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict()
            }
            logx.save_model(save_dict,
                            metric=acc1,
                            epoch=epoch,
                            higher_better=True)
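adjust_learning_rate is not shown in the example above. In the PyTorch ImageNet reference script this example appears to be adapted from, it applies a simple step decay; a sketch under that assumption:

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate by a factor of 10 every 30 epochs."""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr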