    def build_models(self):
        self.G = Generator(self.config.z_dim, self.config.g_conv_dim,
                           self.num_of_classes).to(self.device)
        self.D = Discriminator(self.config.d_conv_dim,
                               self.num_of_classes).to(self.device)

        # Loss and optimizer
        # self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.G_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()),
            self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.D_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()),
            self.config.d_lr, [self.config.beta1, self.config.beta2])

        # Start with pretrained model (if it exists)
        if self.config.pretrained_model != '':
            utils.load_pretrained_model(self)

        if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count(
        ) > 1:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # print networks
        print(self.G)
        print(self.D)
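The utils.load_pretrained_model(self) helper called above is not shown on this page; the following is a minimal sketch of what it might do for this trainer, assuming checkpoints are stored as <step>_G.pth / <step>_D.pth under model_weights_path and that config.pretrained_model holds the step to resume from.

import os
import torch

def load_pretrained_model(trainer):
    # Hypothetical sketch: restore G/D weights saved at a given step and
    # resume the step counter from there. The file layout is an assumption.
    step = trainer.config.pretrained_model
    weights_dir = trainer.config.model_weights_path
    trainer.G.load_state_dict(
        torch.load(os.path.join(weights_dir, '{}_G.pth'.format(step)),
                   map_location=trainer.device))
    trainer.D.load_state_dict(
        torch.load(os.path.join(weights_dir, '{}_D.pth'.format(step)),
                   map_location=trainer.device))
    trainer.start = int(step)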
Example #2
    def configure_model(self):
        # load pre-trained model
        if self.train_params.resume:
            return self.resume_training()
        else:
            symbol, arg_params, aux_params = load_pretrained_model(
                self.model_params.url_prefix,
                self.model_params.name,
                self.model_params.model_epoch,
                self.model_params.dir,
                ctx=self.ctx)
            #self.set_use_global_stats_json()
            #symbol = mx.symbol.load(self.model_params.dir + self.model_params.name + '-symbol.json')
            # adjust the network to satisfy the required input
            if self.mode == 'spatial':
                new_symbol, new_arg_params = self.refactor_model_spatial(
                    symbol, arg_params)
                new_aux_params = aux_params
            elif self.mode == 'temporal':
                new_symbol, new_arg_params, new_aux_params = self.refactor_model_temporal(
                    symbol, arg_params, aux_params)
            else:
                raise NotImplementedError(
                    'The refactoring method {} for the model has not been implemented yet'
                    .format(self.mode))

            new_symbol.save(self.model_params.dir + self.model_params.name +
                            '-' + self.mode + '-symbol.json')
            self.set_use_global_stats_json()
            new_symbol = mx.symbol.load(self.model_params.dir +
                                        self.model_params.name + '-' +
                                        self.mode + '-symbol.json')
            return new_symbol, new_arg_params, new_aux_params
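The MXNet load_pretrained_model used above typically wraps mx.model.load_checkpoint; below is a hedged sketch, where the download step and the exact argument names are assumptions.

import os
import mxnet as mx

def load_pretrained_model(url_prefix, name, epoch, model_dir, ctx=None):
    # Hypothetical sketch: fetch the symbol/params files if they are missing,
    # then load them with the standard MXNet checkpoint loader. ctx is kept
    # only for interface parity with the call site above.
    prefix = os.path.join(model_dir, name)
    symbol_file = '{}-symbol.json'.format(prefix)
    params_file = '{}-{:04d}.params'.format(prefix, epoch)
    if not os.path.exists(symbol_file):
        mx.test_utils.download(url_prefix + os.path.basename(symbol_file), symbol_file)
    if not os.path.exists(params_file):
        mx.test_utils.download(url_prefix + os.path.basename(params_file), params_file)
    return mx.model.load_checkpoint(prefix, epoch)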
Example #3
def load_model_and_continue_training(model_save_path, json_parameters_path,
                                     save_new_result_sheet):
    # append function name to the call sequence
    calling_sequence.append("[load_model_and_continue_training]==>>")
    print(" ==============================================")
    print(
        " [INFO] Entering function[load_model_and_continue_training]  in core.py"
    )
    # -----------------------------------------------------------------------------------------------------
    # load saved parameters
    parameters = load_parameters(json_parameters_path)
    print("loaded parameters", parameters)
    # load saved model
    model = load_pretrained_model(model_path=model_save_path)
    # -----------------------------------------------------------------------------------------------------
    # get the train, validation and test generators plus parameters from prepare_train_valid_data
    # after loading the train, validation and test data (class names, and data for each class)
    train_generator, validation_generator, test_generator, parameters = prepare_train_valid_data(
        parameters)
    # ------------------------------------------------------------------------------------------------------
    # check if train on parallel gpus
    if (parameters['device'] == 'gpu_parallel'):
        print(" [INFO] target multi gpus...")
        _parallel = True
    else:
        print(" [INFO] target single gpu...")
        _parallel = False
    #-----------------------------------------------------------------------------------------------------------------
    # start training
    history, parameters = train(model,
                                parameters,
                                train_generator,
                                validation_generator,
                                test_generator,
                                parallel=_parallel)
    #-----------------------------------------------------------------------------------------------------------------
    # apply testset
    # TODO: change this to work with generators
    max_prediction_time, parameters = calculate_accuaracy_data_set(
        DataPath=parameters['test_data_path'],
        parameters=parameters,
        model=model)
    print("max prediction time = ", max_prediction_time)
    #-----------------------------------------------------------------------------------------------------------------
    # save train result
    save_train_result(history,
                      parameters,
                      initiate_new_result_sheet=save_new_result_sheet)
    #-----------------------------------------------------------------------------------------------------------------
    update_accuracy_in_paramters_on_end_then_save_to_json(parameters, history)
    #-----------------------------------------------------------------------------------------------------------------
    # clear session
    del model, train_generator, validation_generator, parameters, history
    K.clear_session()

    print(" [INFO] calling sequence -> ", calling_sequence)
    calling_sequence.clear()

    print(" [INFO] Leaving function[load_model_and_continue_training]")
    print(" ==============================================")
Example #4
def main():

    global args, best_prec1

    args = parser.parse_args()
    args.workers = 16
    args.seed = int(time.time())  # long() is Python 2 only; int() works here

    # prepare the test dataset
    root = '../dataset/'
    test_step = 5
    test_pair = load_test_data(root, test_step)

    if args.use_loc:
        model_name = 'locmodel_best.pth.tar'
        use_loc = True
        use_trk = False
        if args.use_trk:
            model_name = 'trkmodel_best.pth.tar'
            use_loc = True
            use_trk = True
    else:
        model_name = 'denmodel_best.pth.tar'
        use_loc = False
        use_trk = False

    model = STANet(use_loc, use_trk).cuda()
    model = load_pretrained_model(model_name, model)

    criterion = nn.MSELoss(reduction='sum').cuda()  # size_average=False is the deprecated spelling of reduction='sum'
    with torch.no_grad():
        validate(test_pair, model, criterion)
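The load_pretrained_model(model_name, model) call here presumably just restores weights from a saved .pth.tar checkpoint; a minimal sketch under that assumption:

import torch

def load_pretrained_model(model_name, model):
    # Hypothetical sketch: load a checkpoint dict written by torch.save and
    # copy its 'state_dict' entry (or the dict itself) into the model.
    checkpoint = torch.load(model_name, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict)
    return model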
Example #5
def main():
    in_args = trian_args()
    train_datasets, trainloaders, validloaders, testloaders = transform_data(in_args.data_dir)
    model = load_pretrained_model(in_args.model_name)
    model = replace_classifier(model, in_args.model_name, in_args.hidden_units)

    train_network(model, trainloaders, validloaders, in_args.epochs, 20, in_args.learning_rate, in_args.gpu)
    check_accuracy_on_test(model, testloaders, in_args.gpu)
    save_checkpoint(model, in_args.model_name, train_datasets, in_args.save_dir)
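The replace_classifier helper used above is not shown; the following is a sketch of the usual pattern, assuming a torchvision backbone and a 102-class flower dataset (the layer sizes and attribute names are assumptions).

import torch.nn as nn

def replace_classifier(model, model_name, hidden_units):
    # Hypothetical sketch: freeze the feature extractor and attach a fresh
    # classification head sized for the target dataset.
    for param in model.parameters():
        param.requires_grad = False
    if model_name.startswith('vgg'):
        in_features = model.classifier[0].in_features
    else:  # e.g. resnet-style models expose a single fc layer
        in_features = model.fc.in_features
    head = nn.Sequential(
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(hidden_units, 102),
        nn.LogSoftmax(dim=1),
    )
    if model_name.startswith('vgg'):
        model.classifier = head
    else:
        model.fc = head
    return model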
    def __init__(self, config):

        # Config
        self.config = config

        self.start = 0  # Unless using pre-trained model

        # Create directories if not exist
        utils.make_folder(self.config.save_path)
        utils.make_folder(self.config.model_weights_path)
        utils.make_folder(self.config.sample_images_path)

        # Copy files
        utils.write_config_to_file(self.config, self.config.save_path)
        utils.copy_scripts(self.config.save_path)

        # Check for CUDA
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.config.batch_size_in_gpu, self.config.dataset,
            self.config.data_path, self.config.shuffle, self.config.drop_last,
            self.config.dataloader_args, self.config.resize,
            self.config.imsize, self.config.centercrop,
            self.config.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start with pretrained model (if it exists)
        if self.config.pretrained_model != '':
            utils.load_pretrained_model(self)

        if self.config.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
def exp1(opt):
    model = getattr(models.concrete.single, opt.model)(opt).to(device)
    opt.exp_name += opt.model
    vd = VisionDataset(opt, class_order=list(range(10)))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    logger = get_logger(folder=opt.log_dir + '/' + opt.exp_name + '/')
    logger.info(f'Running with device {device}')
    logger.info("==> Opts for this training: " + str(opt))

    trainer = Trainer(opt, logger, device=device)

    # pretraining
    if opt.num_pretrain_classes > 0:
        try:
            logger.info('Trying to load pretrained model...')
            model = load_pretrained_model(opt, model, logger)
            pretrain = False
        except Exception as e:
            logger.info(f'Failed to load pretrained model: {e}')
            pretrain = True

        if pretrain:
            assert opt.num_pretrain_passes > 0
            logger.info(f'==> Starting pretraining')
            for epoch in range(1, opt.num_pretrain_passes + 1):
                trainer.train(loader=vd.pretrain_loader, model=model, optimizer=optimizer, epoch=epoch)
                acc = trainer.test(loader=vd.pretest_loader, model=model, mask=vd.pretrain_mask, epoch_or_phase=epoch)
            logger.info(f'==> Pretraining completed! Acc: [{acc:.3f}]')
            save_pretrained_model(opt, model)

    if opt.num_tasks > 0:
        # TODO: use another optimizer?
        # Class-Incremental training
        # We start with the pretrain mask because in testing we want pretrained classes included
        logger.info(f'==> Starting Class-Incremental training')
        mask = vd.pretrain_mask.clone() if opt.num_pretrain_classes > 0 else torch.zeros(vd.n_classes_in_whole_dataset)
        dataloaders = vd.get_ci_dataloaders()
        cl_accuracy_meter = AverageMeter()
        for phase, (trainloader, testloader, class_list, phase_mask) in enumerate(dataloaders, start=1):
            trainer.train(loader=trainloader, model=model, optimizer=optimizer, phase=phase)

            # accumulate masks, because we want to test on all seen classes
            mask += phase_mask

            # this is the accuracy for all classes seen so far
            acc = trainer.test(loader=testloader, model=model, mask=mask, epoch_or_phase=phase)
            cl_accuracy_meter.update(acc)

        logger.info(f'==> CL training completed! AverageAcc: [{cl_accuracy_meter.avg:.3f}]')
Example #8
def init_model(args, num_train_pids):

    print("Initializing model: {}".format(args.arch))
    if args.arch.lower() == 'resnet50':
        model = ResNet50TP(num_classes=num_train_pids)
    elif args.arch.lower() == 'alexnet':
        model = AlexNet(num_classes=num_train_pids)
    else:
        assert False, 'unknown model ' + args.arch

    # pretrained model loading
    if args.pretrained_model is not None:
        model = load_pretrained_model(model, args.pretrained_model)
    
    return model
Example #9
def init_model_rl_training(args, num_train_pids):
    base_model = init_model(args, num_train_pids)

    if args.rl_algo == 'ql':
        print('creating agent for Q learning')
        agent_model = Agent_QL(base_model, args)
    elif args.rl_algo == 'pg':
        print('creating agent for Policy gradient')
        agent_model = Agent_PG(base_model, args)
    else:
        assert False, 'unknown rl algo ' + args.rl_algo

    # pretrained model loading
    if args.pretrained_model_rl is not None:
        agent_model = load_pretrained_model(agent_model, args.pretrained_model_rl)
    
    return agent_model
Example #10
def load_checkpoint(checkpoint_name):
    import os
    save_dir = './checkpoint/'
    # if the checkpoint is not an existing file path, prepend the default save directory
    if os.path.isfile(checkpoint_name):
        filepath = checkpoint_name
    else:
        filepath = save_dir + checkpoint_name
    checkpoint = torch.load(filepath)

    # The type of architecture is being used for the loaded checkpoint
    model_name = checkpoint['architecture']

    # Load the appropriate pre-trained model
    model = load_pretrained_model(model_name)

    # Assign values from the checkpoint to the model
    model.class_to_idx = checkpoint['image_dataset']
    model.classifier = checkpoint['classifier']
    model.load_state_dict(checkpoint['state_dict'])
    return model
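For reference, a checkpoint consumed by load_checkpoint above could have been written like this; the key names mirror the ones read back, everything else is assumed.

import torch

def save_checkpoint(model, image_datasets, filepath='./checkpoint/model.pth'):
    # Hypothetical sketch of the matching save side: store the architecture
    # name, the class-to-index mapping, the classifier head and the weights
    # under the keys that load_checkpoint expects.
    checkpoint = {
        'architecture': 'vgg16',                       # assumed example value
        'image_dataset': image_datasets.class_to_idx,  # class-to-index map
        'classifier': model.classifier,
        'state_dict': model.state_dict(),
    }
    torch.save(checkpoint, filepath)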
Example #11
def main():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    tnet = define_tsnet(name=args.t_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['net'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # initialize optimizer
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define attacker
    attacker = BSSAttacker(step_alpha=0.3, num_steps=10, eps=1e-4)

    # define loss functions
    criterionKD = BSS(args.T)
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizer, criterions, attacker, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model
        is_best = False
        if test_top1 > best_top1:
            best_top1 = test_top1
            best_top5 = test_top5
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)
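In distillation code like this, load_pretrained_model(net, state_dict) usually copies a saved state dict into the network while tolerating the 'module.' prefix that DataParallel adds; a hedged sketch:

from collections import OrderedDict

def load_pretrained_model(model, pretrained_dict):
    # Hypothetical sketch: strip any 'module.' prefix left by DataParallel,
    # keep only keys that exist in the target model, then load the result.
    model_dict = model.state_dict()
    cleaned = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        if name in model_dict:
            cleaned[name] = v
    model_dict.update(cleaned)
    model.load_state_dict(model_dict)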
Example #12
import sys

import utils

from parameters import *
from sagan_models import Generator, Discriminator

if __name__ == '__main__':
    config = get_parameters()
    config.command = 'python ' + ' '.join(sys.argv)
    print(config)
    utils.check_for_CUDA(config)

    # Load pretrained model (if provided)
    if config.pretrained_model != '':
        utils.load_pretrained_model(config)
    else:
        assert config.num_of_classes, "Please provide the number of classes! E.g. python3 test.py --num_of_classes 10"
        config.G = Generator(config.z_dim, config.g_conv_dim,
                             config.num_of_classes).to(config.device)
        config.D = Discriminator(config.d_conv_dim,
                                 config.num_of_classes).to(config.device)

    config.G.eval()
    config.D.eval()
    print(config.G, config.D)
Example #13
if args.s_model == 'CNNRIS':
    snet = CNN_RIS()
else:
    raise Exception('Invalid name of the student network...')

if args.t_model == 'Teacher':
    tnet = Teacher()
elif args.t_model == 'Teacher1':
    tnet = Teacher1()
elif args.t_model == 'Teacher3':
    tnet = Teacher3()
else:
    raise Exception('Invalid name of the teacher network...')
tcheckpoint = torch.load(os.path.join(
    'results/' + args.data_name + '_' + args.t_model + '_' + str(args.augmentation),
    'Best_Teacher_model.t7'))
load_pretrained_model(tnet, tcheckpoint['tnet'])
try:
    print('best_Teacher_acc is ' + str(tcheckpoint['test_acc']))
except KeyError:
    print('best_Teacher_acc is ' + str(tcheckpoint['best_PrivateTest_acc']))

if args.distillation == 'SSDEAS':  #  Read student parameters of semi-supervised training
    student_path = os.path.join(
        args.save_root + args.data_name + '_Student_' + str(args.augmentation)
        + '_' + str(args.perTraining), 'PrivateTest_model.t7')
    print('Starting load student parameters.........')
    scheckpoint = torch.load(student_path)
    load_pretrained_model(snet, scheckpoint['snet'])
    print('The path of scheckpoint is:   ' + str(student_path))
    print('best_Student_acc is ' + str(scheckpoint['test_acc']))
    print('best_Student_mAP is ' + str(scheckpoint['test_mAP']))
    print('best_Student_F1 is ' + str(scheckpoint['test_F1']))
Example #14
def main():
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    if args.linear_eval:
        args.save_dir = os.path.join(args.save_dir, 'linear_eval')
    elif args.freeze:
        args.save_dir = os.path.join(args.save_dir, 'freeze_conv1-4')
    else:
        args.save_dir = os.path.join(args.save_dir, 'finetun_all')
    args.save_dir = os.path.join(
        args.save_dir, 'gpus_{}_lr_{}_bs_{}_epochs_{}_pretrained_{}'.format(
            len(args.gpus.split(',')), args.lr, args.batch_size, args.epochs,
            args.pretrained.split('/')[-1]))

    if args.mixup:
        args.save_dir = args.save_dir + '_alpha_{}'.format(args.alpha)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    global best_acc1

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=args.num_classes)

    if args.linear_eval:
        print("=> linear evaluation")
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        # init the fc layer
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()
    else:
        if args.freeze:
            print("=> freeze conv1-conv4")
            for name, param in model.named_parameters():
                if name not in ['fc.weight', 'fc.bias'] and 'layer4' not in name:
                    print(name)
                    param.requires_grad = False
            # init the fc layer
            model.fc.weight.data.normal_(mean=0.0, std=0.01)
            model.fc.bias.data.zero_()
        else:
            print("=> plain training")

    # optionally use pretrained weights
    if args.pretrained:
        if args.no_conv5:
            model = load_pretrained_model_no_conv5(model, args.pretrained)
        else:
            model = load_pretrained_model(model, args.pretrained)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    if args.linear_eval:
        # optimize only the linear classifier
        parameters = list(filter(lambda p: p.requires_grad,
                                 model.parameters()))
        assert len(parameters) == 2  # fc.weight, fc.bias

    # use the filtered parameter list when only the linear head is trained
    optimizer = torch.optim.SGD(
        parameters if args.linear_eval else model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.input_size == 112:
        print('112x112 input')
        # for 112 input
        train_transforms = transforms.Compose([
            transforms.Resize(size=128),
            # transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=112),
            transforms.ToTensor(),
            normalize
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(size=128),
            transforms.CenterCrop(size=112),
            transforms.ToTensor(), normalize
        ])
    elif args.input_size == 224:
        print('224x224 input')
        #for 224 input
        train_transforms = transforms.Compose([
            transforms.Resize(size=256),
            #transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=224),
            transforms.ToTensor(),
            normalize
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(), normalize
        ])

    elif args.input_size == 448:
        print('448x448 input')
        #for 448 input
        train_transforms = transforms.Compose([
            transforms.Resize(size=448),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=448),
            transforms.ToTensor(), normalize
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(size=448),
            transforms.CenterCrop(size=448),
            transforms.ToTensor(), normalize
        ])

    train_dataset = datasets.ImageFolder(traindir, train_transforms)
    val_dataset = datasets.ImageFolder(valdir, val_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    train_start = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        if args.mixup:
            mixup_train(train_loader, model, criterion, optimizer, epoch, args)
        else:
            train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'acc1': acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_dir)

    print('best acc1', best_acc1)
    train_end = time.time()

    print('total training time elapses {} hours'.format(
        (train_end - train_start) / 3600.0))
Example #15
    def __init__(self, config):

        # Images data path & Output path
        self.dataset = config.dataset
        self.data_path = config.data_path
        self.save_path = os.path.join(config.save_path, config.name)

        # Training settings
        self.batch_size = config.batch_size
        self.total_step = config.total_step
        self.d_steps_per_iter = config.d_steps_per_iter
        self.g_steps_per_iter = config.g_steps_per_iter
        self.d_lr = config.d_lr
        self.g_lr = config.g_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.inst_noise_sigma = config.inst_noise_sigma
        self.inst_noise_sigma_iters = config.inst_noise_sigma_iters
        self.start = 0  # Unless using pre-trained model

        # Image transforms
        self.shuffle = config.shuffle
        self.drop_last = config.drop_last
        self.resize = config.resize
        self.imsize = config.imsize
        self.centercrop = config.centercrop
        self.centercrop_size = config.centercrop_size
        self.tanh_scale = config.tanh_scale
        self.normalize = config.normalize

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.save_n_images = config.save_n_images
        self.max_frames_per_gif = config.max_frames_per_gif

        # Pretrained model
        self.pretrained_model = config.pretrained_model

        # Misc
        self.manual_seed = config.manual_seed
        self.disable_cuda = config.disable_cuda
        self.parallel = config.parallel
        self.dataloader_args = config.dataloader_args

        # Output paths
        self.model_weights_path = os.path.join(self.save_path,
                                               config.model_weights_dir)
        self.sample_path = os.path.join(self.save_path, config.sample_dir)

        # Model hyper-parameters
        self.adv_loss = config.adv_loss
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_gp = config.lambda_gp

        # Model name
        self.name = config.name

        # Create directories if not exist
        utils.make_folder(self.save_path)
        utils.make_folder(self.model_weights_path)
        utils.make_folder(self.sample_path)

        # Copy files
        utils.write_config_to_file(config, self.save_path)
        utils.copy_scripts(self.save_path)

        # Check for CUDA
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.batch_size, self.dataset, self.data_path, self.shuffle,
            self.drop_last, self.dataloader_args, self.resize, self.imsize,
            self.centercrop, self.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start with pretrained model (if it exists)
        if self.pretrained_model != '':
            utils.load_pretrained_model(self)

        if self.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
def train(epochs=1, batchSize=2, lr=0.0001, device='cpu', accumulate=True,
          a_step=16, load_saved=False, file_path='./saved_best.pt', use_dtp=False,
          pretrained_model='./bert_pretrain_model/',
          tokenizer_model='bert-base-chinese', weighted_loss=False):
    device = device
    tokenizer = load_tokenizer(tokenizer_model)
    my_net = torch.load(file_path) if load_saved else Net(load_pretrained_model(pretrained_model))
    my_net.to(device, non_blocking=True)
    label_dict = dict()
    with open('./tianchi_datasets/label.json') as f:
        for line in f:
            label_dict = json.loads(line)
            break
    label_weights_dict = dict()
    with open('./tianchi_datasets/label_weights.json') as f:
        for line in f:
            label_weights_dict = json.loads(line)
            break
    ocnli_train = dict()
    with open('./tianchi_datasets/OCNLI/train.json') as f:
        for line in f:
            ocnli_train = json.loads(line)
            break
    ocnli_dev = dict()
    with open('./tianchi_datasets/OCNLI/dev.json') as f:
        for line in f:
            ocnli_dev = json.loads(line)
            break
    ocemotion_train = dict()
    with open('./tianchi_datasets/OCEMOTION/train.json') as f:
        for line in f:
            ocemotion_train = json.loads(line)
            break
    ocemotion_dev = dict()
    with open('./tianchi_datasets/OCEMOTION/dev.json') as f:
        for line in f:
            ocemotion_dev = json.loads(line)
            break
    tnews_train = dict()
    with open('./tianchi_datasets/TNEWS/train.json') as f:
        for line in f:
            tnews_train = json.loads(line)
            break
    tnews_dev = dict()
    with open('./tianchi_datasets/TNEWS/dev.json') as f:
        for line in f:
            tnews_dev = json.loads(line)
            break
    train_data_generator = Data_generator(ocnli_train, ocemotion_train, tnews_train, label_dict, device, tokenizer)
    dev_data_generator = Data_generator(ocnli_dev, ocemotion_dev, tnews_dev, label_dict, device, tokenizer)
    tnews_weights = torch.tensor(label_weights_dict['TNEWS']).to(device, non_blocking=True)
    ocnli_weights = torch.tensor(label_weights_dict['OCNLI']).to(device, non_blocking=True)
    ocemotion_weights = torch.tensor(label_weights_dict['OCEMOTION']).to(device, non_blocking=True)
    loss_object = Calculate_loss(label_dict, weighted=weighted_loss, tnews_weights=tnews_weights, ocnli_weights=ocnli_weights, ocemotion_weights=ocemotion_weights)
    optimizer=torch.optim.Adam(my_net.parameters(), lr=lr)
    best_dev_f1 = 0.0
    best_epoch = -1
    for epoch in range(epochs):
        my_net.train()
        train_loss = 0.0
        train_total = 0
        train_correct = 0
        train_ocnli_correct = 0
        train_ocemotion_correct = 0
        train_tnews_correct = 0
        train_ocnli_pred_list = []
        train_ocnli_gold_list = []
        train_ocemotion_pred_list = []
        train_ocemotion_gold_list = []
        train_tnews_pred_list = []
        train_tnews_gold_list = []
        cnt_train = 0
        while True:
            raw_data = train_data_generator.get_next_batch(batchSize)
            if raw_data is None:
                break
            data = dict()
            data['input_ids'] = raw_data['input_ids']
            data['token_type_ids'] = raw_data['token_type_ids']
            data['attention_mask'] = raw_data['attention_mask']
            data['ocnli_ids'] = raw_data['ocnli_ids']
            data['ocemotion_ids'] = raw_data['ocemotion_ids']
            data['tnews_ids'] = raw_data['tnews_ids']
            tnews_gold = raw_data['tnews_gold']
            ocnli_gold = raw_data['ocnli_gold']
            ocemotion_gold = raw_data['ocemotion_gold']
            if not accumulate:
                optimizer.zero_grad()
            ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
            if use_dtp:
                tnews_kpi = 0.1 if len(train_tnews_pred_list) == 0 else train_tnews_correct / len(train_tnews_pred_list)
                ocnli_kpi = 0.1 if len(train_ocnli_pred_list) == 0 else train_ocnli_correct / len(train_ocnli_pred_list)
                ocemotion_kpi = 0.1 if len(train_ocemotion_pred_list) == 0 else train_ocemotion_correct / len(train_ocemotion_pred_list)
                current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold,
                                                   ocemotion_gold, tnews_kpi, ocnli_kpi, ocemotion_kpi)
            else:
                current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
            train_loss += current_loss.item()
            current_loss.backward()
            if accumulate and (cnt_train + 1) % a_step == 0:
                optimizer.step()
                optimizer.zero_grad()
            if not accumulate:
                optimizer.step()
            if use_dtp:
                good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb = loss_object.correct_cnt_each(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
                tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
                train_ocemotion_correct += good_ocemotion_nb
                train_ocnli_correct += good_ocnli_nb
                train_tnews_correct += good_tnews_nb
            else:
                tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
            train_correct += tmp_good
            train_total += tmp_total
            p, g = loss_object.collect_pred_and_gold(ocnli_pred, ocnli_gold)
            train_ocnli_pred_list += p
            train_ocnli_gold_list += g
            p, g = loss_object.collect_pred_and_gold(ocemotion_pred, ocemotion_gold)
            train_ocemotion_pred_list += p
            train_ocemotion_gold_list += g
            p, g = loss_object.collect_pred_and_gold(tnews_pred, tnews_gold)
            train_tnews_pred_list += p
            train_tnews_gold_list += g
            cnt_train += 1
            #torch.cuda.empty_cache()
            if (cnt_train + 1) % 1000 == 0:
                print('[', cnt_train + 1, '- th batch : train acc is:', train_correct / train_total, '; train loss is:', train_loss / cnt_train, ']')
        if accumulate:
            optimizer.step()
        optimizer.zero_grad()
        train_ocnli_f1 = get_f1(train_ocnli_gold_list, train_ocnli_pred_list)
        train_ocemotion_f1 = get_f1(train_ocemotion_gold_list, train_ocemotion_pred_list)
        train_tnews_f1 = get_f1(train_tnews_gold_list, train_tnews_pred_list)
        train_avg_f1 = (train_ocnli_f1 + train_ocemotion_f1 + train_tnews_f1) / 3
        print(epoch, 'th epoch train average f1 is:', train_avg_f1)
        print(epoch, 'th epoch train ocnli is below:')
        print_result(train_ocnli_gold_list, train_ocnli_pred_list)
        print(epoch, 'th epoch train ocemotion is below:')
        print_result(train_ocemotion_gold_list, train_ocemotion_pred_list)
        print(epoch, 'th epoch train tnews is below:')
        print_result(train_tnews_gold_list, train_tnews_pred_list)
        
        train_data_generator.reset()
        
        my_net.eval()
        dev_loss = 0.0
        dev_total = 0
        dev_correct = 0
        dev_ocnli_correct = 0
        dev_ocemotion_correct = 0
        dev_tnews_correct = 0
        dev_ocnli_pred_list = []
        dev_ocnli_gold_list = []
        dev_ocemotion_pred_list = []
        dev_ocemotion_gold_list = []
        dev_tnews_pred_list = []
        dev_tnews_gold_list = []
        cnt_dev = 0
        with torch.no_grad():
            while True:
                raw_data = dev_data_generator.get_next_batch(batchSize)
                if raw_data is None:
                    break
                data = dict()
                data['input_ids'] = raw_data['input_ids']
                data['token_type_ids'] = raw_data['token_type_ids']
                data['attention_mask'] = raw_data['attention_mask']
                data['ocnli_ids'] = raw_data['ocnli_ids']
                data['ocemotion_ids'] = raw_data['ocemotion_ids']
                data['tnews_ids'] = raw_data['tnews_ids']
                tnews_gold = raw_data['tnews_gold']
                ocnli_gold = raw_data['ocnli_gold']
                ocemotion_gold = raw_data['ocemotion_gold']
                ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
                if use_dtp:
                    tnews_kpi = 0.1 if len(dev_tnews_pred_list) == 0 else dev_tnews_correct / len(
                        dev_tnews_pred_list)
                    ocnli_kpi = 0.1 if len(dev_ocnli_pred_list) == 0 else dev_ocnli_correct / len(
                        dev_ocnli_pred_list)
                    ocemotion_kpi = 0.1 if len(dev_ocemotion_pred_list) == 0 else dev_ocemotion_correct / len(
                        dev_ocemotion_pred_list)
                    current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold,
                                                           ocnli_gold,
                                                           ocemotion_gold, tnews_kpi, ocnli_kpi, ocemotion_kpi)
                else:
                    current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                dev_loss += current_loss.item()
                if use_dtp:
                    good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb = loss_object.correct_cnt_each(
                        tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                    tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
                    tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
                    dev_ocemotion_correct += good_ocemotion_nb
                    dev_ocnli_correct += good_ocnli_nb
                    dev_tnews_correct += good_tnews_nb
                else:
                    tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                dev_correct += tmp_good
                dev_total += tmp_total
                p, g = loss_object.collect_pred_and_gold(ocnli_pred, ocnli_gold)
                dev_ocnli_pred_list += p
                dev_ocnli_gold_list += g
                p, g = loss_object.collect_pred_and_gold(ocemotion_pred, ocemotion_gold)
                dev_ocemotion_pred_list += p
                dev_ocemotion_gold_list += g
                p, g = loss_object.collect_pred_and_gold(tnews_pred, tnews_gold)
                dev_tnews_pred_list += p
                dev_tnews_gold_list += g
                cnt_dev += 1
                #torch.cuda.empty_cache()
                #if (cnt_dev + 1) % 1000 == 0:
                #    print('[', cnt_dev + 1, '- th batch : dev acc is:', dev_correct / dev_total, '; dev loss is:', dev_loss / cnt_dev, ']')
            dev_ocnli_f1 = get_f1(dev_ocnli_gold_list, dev_ocnli_pred_list)
            dev_ocemotion_f1 = get_f1(dev_ocemotion_gold_list, dev_ocemotion_pred_list)
            dev_tnews_f1 = get_f1(dev_tnews_gold_list, dev_tnews_pred_list)
            dev_avg_f1 = (dev_ocnli_f1 + dev_ocemotion_f1 + dev_tnews_f1) / 3
            print(epoch, 'th epoch dev average f1 is:', dev_avg_f1)
            print(epoch, 'th epoch dev ocnli is below:')
            print_result(dev_ocnli_gold_list, dev_ocnli_pred_list)
            print(epoch, 'th epoch dev ocemotion is below:')
            print_result(dev_ocemotion_gold_list, dev_ocemotion_pred_list)
            print(epoch, 'th epoch dev tnews is below:')
            print_result(dev_tnews_gold_list, dev_tnews_pred_list)

            dev_data_generator.reset()
            
            if dev_avg_f1 > best_dev_f1:
                best_dev_f1 = dev_avg_f1
                best_epoch = epoch
                torch.save(my_net, file_path)
            print('best epoch is:', best_epoch, '; with best f1 is:', best_dev_f1)
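In this multi-task example, load_pretrained_model(pretrained_model) most likely builds the BERT encoder that Net wraps; a sketch assuming the HuggingFace transformers API:

from transformers import BertModel

def load_pretrained_model(pretrained_model='./bert_pretrain_model/'):
    # Hypothetical sketch: load a local (or hub-hosted) BERT encoder on top of
    # which Net adds the three task-specific classification heads.
    return BertModel.from_pretrained(pretrained_model)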
Example #17
        # output (without dropout)
        x = F.log_softmax(self.fc3(x), dim=1)

        return x


#-------------------------------------------------------------------------------
# get model
#-------------------------------------------------------------------------------

# load data
train_loader, valid_loader, test_loader, class_to_idx, image_datasets = utils.get_loaders(
    data_dir=args.data_dir)

# load pretrained network
model = utils.load_pretrained_model(arch=args.arch)

#-------------------------------------------------------------------------------
# set hyperparams for model training
#-------------------------------------------------------------------------------
for param in model.parameters():
    param.requires_grad = False
input_size = model.classifier[0].in_features
hidden_size = args.hidden_units
output_size = 102
p_dropout = 0.5
model.classifier = Classifier()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=args.learning_rate)
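Here utils.load_pretrained_model(arch=args.arch) presumably resolves a torchvision architecture by name; a minimal sketch under that assumption:

from torchvision import models as tv_models

def load_pretrained_model(arch='vgg16'):
    # Hypothetical sketch: look the architecture up by name in torchvision and
    # return it with ImageNet weights.
    return tv_models.__dict__[arch](pretrained=True)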

#-------------------------------------------------------------------------------
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))

    model = simclr.builder.SimCLR(models.__dict__[args.arch], args.simclr_dim,
                                  args.mlp)
    print(model)

    args.warmup_epochs = 0
    # warm-up for large-batch training
    if args.batch_size > 256:
        args.warm = True
    if getattr(args, 'warm', False):  # guard in case --warm is not a parser flag
        args.warmup_from = 0.01
        args.warmup_epochs = 10

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        #raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        model = torch.nn.DataParallel(model)
        model = model.cuda(args.gpu)
        #raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = simclr.losses.SupConLoss(temperature=args.simclr_t).cuda(
        args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # optionally resume from a checkpoint
    if args.pretrained:
        model = load_pretrained_model(model, args.pretrained)
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = [
            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.)),
            transforms.RandomApply(
                [
                    transforms.ColorJitter(0.4, 0.4, 0.4,
                                           0.1)  # not strengthened
                ],
                p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([simclr.loader.GaussianBlur([.1, 2.])],
                                   p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]
    else:
        # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = [
            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]

    train_dataset = datasets.ImageFolder(
        traindir,
        simclr.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    train_start = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        #adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if epoch % 100 == 0 or epoch == args.epochs - 1:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    is_best=False,
                    root=args.save_dir,
                    filename='checkpoint_{:04d}.pth.tar'.format(epoch))
    train_end = time.time()

    print('total training time elapses {} hours'.format(
        (train_end - train_start) / 3600.0))
Example #19
def main():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    net1 = define_tsnet(name=args.net1_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.net1_init)
    load_pretrained_model(net1, checkpoint['net'])
    logging.info('Net1: %s', net1)
    logging.info('Net1 param size = %fMB', count_parameters_in_MB(net1))

    net2 = define_tsnet(name=args.net2_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.net2_init)
    load_pretrained_model(net2, checkpoint['net'])
    logging.info('Net2: %s', net2)
    logging.info('Net2 param size = %fMB', count_parameters_in_MB(net2))
    logging.info('-----------------------------------------------')

    # initialize optimizer
    optimizer1 = torch.optim.SGD(net1.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)
    optimizer2 = torch.optim.SGD(net2.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    # define loss functions
    criterionKD = DML()
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'net1': net1, 'net2': net2}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}
    optimizers = {'optimizer1': optimizer1, 'optimizer2': optimizer2}

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizers, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizers, criterions, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top11, test_top15, test_top21, test_top25 = test(
            test_loader, nets, criterions)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model
        is_best = False
        if max(test_top11, test_top21) > best_top1:
            best_top1 = max(test_top11, test_top21)
            best_top5 = max(test_top15, test_top25)
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'net1': net1.state_dict(),
                'net2': net2.state_dict(),
                'prec1@1': test_top11,
                'prec1@5': test_top15,
                'prec2@1': test_top21,
                'prec2@5': test_top25,
            }, is_best, args.save_root)
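
# DML() above is imported from a losses module that is not part of this
# snippet. In deep mutual learning, each peer network is pulled towards the
# other's predictive distribution with a KL term; the class below is only a
# minimal sketch of such a criterion (an assumption for illustration, not
# necessarily the exact implementation the script imports).
import torch.nn as nn
import torch.nn.functional as F

class DML(nn.Module):
    """KL divergence from one peer's softened outputs to the other's."""

    def forward(self, out_self, out_peer):
        # soften both sets of logits and average the KL divergence over the batch
        return F.kl_div(F.log_softmax(out_self, dim=1),
                        F.softmax(out_peer, dim=1),
                        reduction='batchmean')
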
def main():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    tnet = define_tsnet(name=args.t_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['net'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # define loss functions
    if args.kd_mode == 'logits':
        criterionKD = Logits()
    elif args.kd_mode == 'st':
        criterionKD = SoftTarget(args.T)
    elif args.kd_mode == 'at':
        criterionKD = AT(args.p)
    elif args.kd_mode == 'fitnet':
        criterionKD = Hint()
    elif args.kd_mode == 'nst':
        criterionKD = NST()
    elif args.kd_mode == 'pkt':
        criterionKD = PKTCosSim()
    elif args.kd_mode == 'fsp':
        criterionKD = FSP()
    elif args.kd_mode == 'rkd':
        criterionKD = RKD(args.w_dist, args.w_angle)
    elif args.kd_mode == 'ab':
        criterionKD = AB(args.m)
    elif args.kd_mode == 'sp':
        criterionKD = SP()
    elif args.kd_mode == 'sobolev':
        criterionKD = Sobolev()
    elif args.kd_mode == 'cc':
        criterionKD = CC(args.gamma, args.P_order)
    elif args.kd_mode == 'lwm':
        criterionKD = LwM()
    elif args.kd_mode == 'irg':
        criterionKD = IRG(args.w_irg_vert, args.w_irg_edge, args.w_irg_tran)
    elif args.kd_mode == 'vid':
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for s_c, t_c in zip(s_channels, t_channels):
            criterionKD.append(VID(s_c, int(args.sf * t_c), t_c,
                                   args.init_var))
        criterionKD = [c.cuda()
                       for c in criterionKD] if args.cuda else criterionKD
        criterionKD = [None] + criterionKD  # None is a placeholder
    elif args.kd_mode == 'ofd':
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for s_c, t_c in zip(s_channels, t_channels):
            criterionKD.append(
                OFD(s_c, t_c).cuda() if args.cuda else OFD(s_c, t_c))
        criterionKD = [None] + criterionKD  # None is a placeholder
    elif args.kd_mode == 'afd':
        # t_channels is the same as s_channels
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for t_c in t_channels:
            criterionKD.append(
                AFD(t_c, args.att_f).cuda() if args.cuda else AFD(t_c, args.att_f))
        criterionKD = [None] + criterionKD  # None is a placeholder
        # # t_chws is same with s_chws
        # s_chws = snet.module.get_chw_num()[1:4]
        # t_chws = tnet.module.get_chw_num()[1:4]
        # criterionKD = []
        # for t_chw in t_chws:
        # 	criterionKD.append(AFD(t_chw).cuda() if args.cuda else AFD(t_chw))
        # criterionKD = [None] + criterionKD # None is a placeholder
    else:
        raise Exception('Invalid kd mode...')
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # initialize optimizer
    if args.kd_mode in ['vid', 'ofd', 'afd']:
        optimizer = torch.optim.SGD(chain(
            snet.parameters(), *[c.parameters() for c in criterionKD[1:]]),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    else:
        optimizer = torch.optim.SGD(snet.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # wrap nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    # first, initialize the student net
    if args.kd_mode in ['fsp', 'ab']:
        logging.info('The first stage, student initialization......')
        train_init(train_loader, nets, optimizer, criterions, 50)
        args.lambda_kd = 0.0
        logging.info('The second stage, softmax training......')

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizer, criterions, epoch)

        # evaluate on the test set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model
        is_best = False
        if test_top1 > best_top1:
            best_top1 = test_top1
            best_top5 = test_top5
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)
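
# adjust_lr() is defined elsewhere in these training scripts; the function
# below is only a minimal sketch of a step-decay schedule it might implement
# (the base learning rate, milestone epochs, and decay factor are assumptions,
# not values taken from the snippets above). The first example passes a dict
# of optimizers and would simply loop over its values.
def adjust_lr(optimizer, epoch, base_lr=0.1, milestones=(100, 150), gamma=0.1):
    # decay the base learning rate by `gamma` once per milestone already passed
    lr = base_lr * (gamma ** sum(epoch > m for m in milestones))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
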
Example #21
def main():
    global global_step
    global best_prec1

    start_time = time.time()
    dataset_config = voc2007()
    train_dataset = dataset_config.get('train_dataset')
    val_dataset = dataset_config.get('val_dataset')
    num_classes = dataset_config.get('num_classes')
    #train_loader = create_train_loader(train_dataset, args=args)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               pin_memory=True,
                                               drop_last=True)

    if val_dataset is not None:
        eval_loader = torch.utils.data.DataLoader(val_dataset,
                                                  batch_size=eval_batch_size,
                                                  shuffle=False,
                                                  num_workers=8,
                                                  pin_memory=True,
                                                  drop_last=False)
    else:
        eval_loader = None
    print("=> load dataset in {} seconds".format(time.time() - start_time))

    #if 'voc' in args.dataset:
    validate = validate_voc
    #else:
    #validate = validate_coco

    print("=> creating model ")

    model = models.__dict__[arch](num_classes=num_classes)

    #print(parameters_string(model))

    model = load_pretrained_model(model, pretrained)
    model = model.cuda()
    model = nn.DataParallel(model)

    class_criterion = nn.BCEWithLogitsLoss()

    if finetune_fc:
        print('=> Finetune only FC layer')
        paras = model.module.fc.parameters()

        #print('=> Finetune only FC + layer4')
        #paras = [{'params': model.module.fc.parameters(),
        #          'params': model.module.layer4.parameters()}]
    else:
        print('=> Training all layers')
        paras = model.parameters()

    print('start learning rate ', lr)
    optimizer = torch.optim.SGD(paras,
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    cudnn.benchmark = True

    if evaluate:
        results_dir = './predict'
        validate_voc_file(eval_loader, model, thre, 0, print_freq, results_dir)
        #validate(eval_loader, model, args.thre, 0, context.vis_log, LOG, args.print_freq)

    for epoch in range(epochs):

        # train for one epoch
        train(train_loader, model, class_criterion, optimizer, epoch)

        if (evaluation_epochs and (epoch + 1) % evaluation_epochs == 0
                and eval_loader is not None):

            prec1 = validate(eval_loader, model, thre, epoch, print_freq)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
        else:
            is_best = False

        if checkpoint_epochs and (epoch + 1) % checkpoint_epochs == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'global_step': global_step,
                    'arch': arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, checkpoint_path, epoch + 1)

    save_checkpoint(
        {
            'epoch': epoch + 1,
            'global_step': global_step,
            'arch': arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, False, checkpoint_path, 'final')
    print("best_prec1 {}".format(best_prec1))
Example #22
    # Resume from a snapshot
    if args.resume:
        logging.warning("resumed from %s" % args.resume)
        torch_load(args.resume, model, optimizer)
        setattr(args, "start_epoch", int(args.resume.split('.')[-1]) + 1)
    else:
        setattr(args, "start_epoch", 1)

    if args.load_pretrained_model:
        model_path, modules_to_load, exclude_modules = args.load_pretrained_model.split(
            ":")
        logging.warning("load pretrained model from %s" %
                        args.load_pretrained_model)
        load_pretrained_model(model=model,
                              model_path=model_path,
                              modules_to_load=modules_to_load,
                              exclude_modules=exclude_modules)
    if args.load_head_from_pretrained_model:
        logging.warning("load pretrained model head from %s" %
                        args.load_head_from_pretrained_model)
        load_head_from_pretrained_model(
            model=model, model_path=args.load_head_from_pretrained_model)

    logging.warning("Total parameter of the model = " +
                    str(sum(p.numel() for p in model.parameters())))
    logging.warning("Trainable parameter of the model = " + str(
        sum(p.numel()
            for p in filter(lambda x: x.requires_grad, model.parameters()))))

    if args.ngpu > 1 and args.dist_train:
        model = torch.nn.parallel.DistributedDataParallel(