Example #1
def main():
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown initial type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # resume from a checkpoint or create a new log directory
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer
        cur_stage = checkpoint["cur_stage"]
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

        start_search_iter = checkpoint["search_iter"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        controller.load_state_dict(checkpoint["ctrl_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
        prev_archs = checkpoint["prev_archs"]
        prev_hiddens = checkpoint["prev_hiddens"]

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "controller_steps": start_search_iter * args.ctrl_step,
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc="search progress"):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f"=> grow to stage {cur_stage}")
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(
            args,
            gen_net,
            dis_net,
            g_loss_history,
            d_loss_history,
            controller,
            gen_optimizer,
            dis_optimizer,
            train_loader,
            prev_hiddens=prev_hiddens,
            prev_archs=prev_archs,
        )
        train_controller(
            args,
            controller,
            ctrl_optimizer,
            gen_net,
            prev_hiddens,
            prev_archs,
            writer_dict,
        )

        if dynamic_reset:
            logger.info("re-initialize share GAN")
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        save_checkpoint(
            {
                "cur_stage": cur_stage,
                "search_iter": search_iter + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "controller": args.controller,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "ctrl_state_dict": controller.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "ctrl_optimizer": ctrl_optimizer.state_dict(),
                "prev_archs": prev_archs,
                "prev_hiddens": prev_hiddens,
                "path_helper": args.path_helper,
            },
            False,
            args.path_helper["ckpt_path"],
        )

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
Example #2
def main(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)
    device = torch.device('cuda')
    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size

    ### model ###
    if args.model == 'memdpc':
        model = MemDPC_BD(sample_size=args.img_dim, 
                        num_seq=args.num_seq, 
                        seq_len=args.seq_len, 
                        network=args.net, 
                        pred_step=args.pred_step,
                        mem_size=args.mem_size)
    else: 
        raise NotImplementedError('wrong model!')

    model.to(device)
    model = nn.DataParallel(model)
    model_without_dp = model.module

    ### optimizer ###
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    criterion = nn.CrossEntropyLoss()

    ### data ###
    transform = transforms.Compose([
        A.RandomSizedCrop(size=224, consistent=True, p=1.0), # crop from 256 to 224
        A.Scale(size=(args.img_dim,args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.RandomGray(consistent=False, p=0.25),
        A.ColorJitter(0.5, 0.5, 0.5, 0.25, consistent=False, p=1.0),
        A.ToTensor(),
        A.Normalize()
    ])

    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    if 'ucf' in args.dataset: 
        lr_milestones_eps = [300,400]
    elif 'k400' in args.dataset: 
        lr_milestones_eps = [120,160]
    else: 
        lr_milestones_eps = [1000] # NEVER
    lr_milestones = [len(train_loader) * m for m in lr_milestones_eps]
    print('=> Using lr_scheduler milestones: %s epochs == %s iterations' % (str(lr_milestones_eps), str(lr_milestones)))
    lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=lr_milestones, repeat=1)
    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    best_acc = 0
    args.iteration = 1

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model_without_dp.load_state_dict(checkpoint['state_dict'])
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except:
                print('[WARNING] Not loading optimizer states')
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))
            sys.exit(0)

    # logging tools
    args.img_path, args.model_path = set_path(args)
    args.logger = Logger(path=args.img_path)
    args.logger.log('args=\n\t\t'+'\n\t\t'.join(['%s:%s'%(str(k),str(v)) for k,v in vars(args).items()]))

    args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    args.writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train'))
    
    torch.backends.cudnn.benchmark = True

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)

        train_loss, train_acc = train_one_epoch(train_loader, 
                                                model, 
                                                criterion, 
                                                optimizer, 
                                                lr_scheduler, 
                                                device, 
                                                epoch, 
                                                args)
        val_loss, val_acc = validate(val_loader, 
                                     model, 
                                     criterion, 
                                     device, 
                                     epoch, 
                                     args)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_dict = {'epoch': epoch,
                     'state_dict': model_without_dp.state_dict(),
                     'best_acc': best_acc,
                     'optimizer': optimizer.state_dict(),
                     'iteration': args.iteration}
        save_checkpoint(save_dict, is_best, 
            filename=os.path.join(args.model_path, 'epoch%s.pth.tar' % str(epoch)), 
            keep_all=False)

    print('Training from ep %d to ep %d finished' 
        % (args.start_epoch, args.epochs))
    sys.exit(0)
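Example #2 builds its LambdaLR schedule from MultiStepLR_Restart_Multiplier, which is defined elsewhere in that repository. Based purely on how it is called above, a plausible sketch (an assumption, not the original function) decays the multiplier by gamma at every milestone and optionally restarts the schedule:

def MultiStepLR_Restart_Multiplier(it, gamma=0.1, step=(300, 400), repeat=1):
    # LR multiplier for LambdaLR: multiply by gamma at every milestone in `step`,
    # restarting the schedule `repeat` times before keeping the final factor
    max_step = step[-1]
    if it // max_step >= repeat:
        # all restart cycles finished: keep the last decay factor
        exponent = len(step) - 1
    else:
        effective_it = it % max_step  # position inside the current cycle
        exponent = len([s for s in step if effective_it >= s])
    return gamma ** exponent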
Example #3
    def train(self, resume=False):
        global fold
        start_epoch = 0
        best_precision = 0
        if resume:
            checkpoint = torch.load(configs.best_models + configs.model_name +
                                    os.sep + str(fold) + '/model_best.pth.tar')
            start_epoch = checkpoint["epoch"]
            fold = checkpoint["fold"]
            best_precision = checkpoint["best_precision"]
            self.model.load_state_dict(checkpoint["state_dict"])
            self.opt.load_state_dict(checkpoint["optimizer"])

        total = len(get_train_data())
        for epoch in range(start_epoch, configs.epochs):
            train_progressor = ProgressBar(mode="Train",
                                           epoch=epoch,
                                           total_epoch=configs.epochs,
                                           model_name=configs.model_name,
                                           total=total)
            epoch_loss = 0
            start = time.time()
            epoch_acc = 0
            self.model.train()
            if epoch < 20:
                td = get_train_data()
            else:
                td = get_train_data(True)
            for i, (x, y) in enumerate(td):
                x = x.to(device)
                y = y.to(device)
                train_progressor.current = i
                self.opt.zero_grad()
                outputs = self.model(x)
                loss = self.criterion(y, outputs)
                loss.backward()
                self.opt.step()
                epoch_loss += loss.item()
                acc = self.acc.iou(y, outputs)
                epoch_acc += acc
                time_cost = time.time() - start
                train_progressor.current_loss = loss.item()
                train_progressor.current_acc = acc
                train_progressor()
            iter = total
            train_progressor.done(time_cost, epoch_loss / iter,
                                  epoch_acc / iter)
            # writer.add_scalar('avg_epoch_train_loss', epoch_loss / iter, epoch)
            # writer.add_scalar('avg_epoch_train_acc', epoch_acc / iter, epoch)
            val_loss, val_acc = self.evaluate(epoch)
            # writer.add_scalar('avg_epoch_val_loss', val_loss, epoch)
            # writer.add_scalar('avg_epoch_val_acc', val_acc, epoch)
            is_best = val_acc > best_precision
            best_precision = max(val_acc, best_precision)
            u.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_name": configs.model_name,
                    "state_dict": self.model.state_dict(),
                    "best_precision": best_precision,
                    "optimizer": self.opt.state_dict(),
                    "fold": fold,
                    "valid_loss": val_loss,
                    "valid_acc": val_acc,
                }, is_best, fold)
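Example #3 computes accuracy with self.acc.iou(y, outputs), a metric object not shown here. For binary segmentation, a generic intersection-over-union score (a minimal sketch, not the repository's own implementation) can be written as:

import torch

def iou_score(target, output, threshold=0.5, eps=1e-7):
    # binarize the raw network output and compare it against the ground-truth mask
    pred = (torch.sigmoid(output) > threshold).float()
    target = (target > threshold).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return ((intersection + eps) / (union + eps)).item()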
Example #4
def main():

    config.workspace = os.path.join(config.workspace_dir, config.exp_name)
    if config.restart_training:
        shutil.rmtree(config.workspace, ignore_errors=True)
    if not os.path.exists(config.workspace):
        os.makedirs(config.workspace)

    shutil.rmtree(os.path.join(config.workspace, 'train_log'),
                  ignore_errors=True)
    logger = setup_logger(os.path.join(config.workspace, 'train_log'))
    logger.info(config.print())

    torch.manual_seed(config.seed)  # set the random seed for the CPU
    if config.gpu_id is not None and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        logger.info('train with gpu {} and pytorch {}'.format(
            config.gpu_id, torch.__version__))
        device = torch.device("cuda:0")
        torch.cuda.manual_seed(config.seed)  # set the random seed for the current GPU
        torch.cuda.manual_seed_all(config.seed)  # set the random seed for all GPUs
    else:
        logger.info('train with cpu and pytorch {}'.format(torch.__version__))
        device = torch.device("cpu")

    train_data = Synthtext(config.trainroot,
                           data_shape=config.data_shape,
                           n=config.n,
                           m=config.m)
    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=config.train_batch_size,
                                   shuffle=True,
                                   num_workers=int(config.workers))

    # writer = SummaryWriter(config.output_dir)
    model = PSENet(backbone=config.backbone,
                   pretrained=config.pretrained,
                   result_num=config.n,
                   scale=config.scale)
    if not config.pretrained and not config.restart_training:
        model.apply(weights_init)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    # dummy_input = torch.autograd.Variable(torch.Tensor(1, 3, 600, 800).to(device))
    # writer.add_graph(models=models, input_to_model=dummy_input)
    criterion = PSELoss(Lambda=config.Lambda,
                        ratio=config.OHEM_ratio,
                        reduction='mean')
    # optimizer = torch.optim.SGD(models.parameters(), lr=config.lr, momentum=0.99)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    if config.checkpoint != '' and not config.restart_training:
        start_epoch = load_checkpoint(config.checkpoint, model, logger, device,
                                      optimizer)
        start_epoch += 1
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            config.lr_decay_step,
            gamma=config.lr_gamma,
            last_epoch=start_epoch)
    else:
        start_epoch = config.start_epoch
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         config.lr_decay_step,
                                                         gamma=config.lr_gamma)

    all_step = len(train_loader)
    logger.info('train dataset has {} samples, {} batches in the dataloader'.format(
        train_data.__len__(), all_step))
    epoch = 0
    f1 = 0
    try:

        for epoch in range(start_epoch, config.epochs):
            start = time.time()
            train_loss, lr = train_epoch(model, optimizer, scheduler,
                                         train_loader, device, criterion,
                                         epoch, all_step, logger)
            logger.info(
                '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                    epoch, config.epochs, train_loss,
                    time.time() - start, lr))
            # net_save_path = '{}/PSENet_{}_loss{:.6f}.pth'.format(config.output_dir, epoch,
            #                                                                               train_loss)
            # save_checkpoint(net_save_path, models, optimizer, epoch, logger)
            if epoch < 100:
                best_save_path = '{}/epoch_{:.1f}_model.pth'.format(
                    config.workspace, epoch)

                save_checkpoint(best_save_path, model, optimizer, epoch,
                                logger)

                # f_score_new = eval(model, os.path.join(config.workspace, 'output'), config.testroot, device)
                # logger.info('  ---------------------------------------')
                # logger.info('     test: f_score : {:.6f}'.format(f_score_new))
                # logger.info('  ---------------------------------------')
                # net_save_path = '{}/PSENet_{}_loss{:.6f}_r{:.6f}.pth'.format(config.output_dir, epoch,
                #                                                                               train_loss,
                #                                                                               f_score_new)
                # save_checkpoint(net_save_path, model, optimizer, epoch, logger)
                # if f_score_new > f1:
                #     f1 = f_score_new
                #     best_save_path = '{}/Best_model_{:.6f}.pth'.format(config.workspace,f1)
                #
                #     save_checkpoint(best_save_path, model, optimizer, epoch, logger)

                # writer.add_scalar(tag='Test/recall', scalar_value=recall, global_step=epoch)
                # writer.add_scalar(tag='Test/precision', scalar_value=precision, global_step=epoch)
                # writer.add_scalar(tag='Test/f1', scalar_value=f1, global_step=epoch)
        # writer.close()
    except KeyboardInterrupt:
        save_checkpoint('{}/final.pth'.format(config.workspace), model,
                        optimizer, epoch, logger)
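Examples #4 and #6 call model.apply(weights_init) on a freshly constructed PSENet, but the weights_init function itself is defined elsewhere. A common pattern for a conv/batch-norm backbone (an assumption about what that helper does, not the original code) is Kaiming initialization for convolutions and unit scale / zero shift for batch norm:

import torch.nn as nn

def weights_init(m):
    # applied to every sub-module via model.apply(weights_init)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)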
Example #5
def main():
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
        config, is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
    # dummy input of shape (batch, 3, H, W), e.g. 32 x 3 x 256 x 192
    dump_input = torch.rand(
        (config.TRAIN.BATCH_SIZE, 3, config.MODEL.IMAGE_SIZE[1],
         config.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False
    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        #test train_loader
        dataiter = train_dataset[0]
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
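Note that Examples #5, #7 and #10 call lr_scheduler.step() at the top of the epoch loop. Since PyTorch 1.1 the documented order is optimizer.step() first and scheduler.step() after the epoch; calling the scheduler first skips the initial value of the schedule. A minimal self-contained illustration of the recommended order:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for epoch in range(10):
    # ... run the training epoch here; optimizer.step() is called for every batch ...
    optimizer.step()
    scheduler.step()  # step the scheduler once per epoch, after optimizer.step()
    print(epoch, scheduler.get_last_lr())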
Example #6
def main():

    args = parse_args()
    trainroot = args.trainroot
    testroot = args.testroot
    backbone = args.backbone
    print("trainroot:", trainroot)
    print("testroot:", testroot)
    print("backbone:", backbone)

    if config.output_dir is None:
        config.output_dir = 'output'
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)

    logger = setup_logger(os.path.join(config.output_dir, 'train_log'))
    logger.info(config.print())

    torch.manual_seed(config.seed)  # set the random seed for the CPU
    if config.gpu_id is not None and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        logger.info('train with gpu {} and pytorch {}'.format(
            config.gpu_id, torch.__version__))
        device = torch.device("cuda:0")
        torch.cuda.manual_seed(config.seed)  # set the random seed for the current GPU
        torch.cuda.manual_seed_all(config.seed)  # set the random seed for all GPUs
    else:
        logger.info('train with cpu and pytorch {}'.format(torch.__version__))
        device = torch.device("cpu")

    train_data = MyDataset(trainroot,
                           data_shape=config.data_shape,
                           n=config.n,
                           m=config.m,
                           transform=transforms.ToTensor())

    print("len(train_data):", len(train_data))
    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=config.train_batch_size,
                                   shuffle=True,
                                   num_workers=int(config.workers))

    writer = SummaryWriter(config.output_dir)
    model = PSENet(backbone=backbone,
                   pretrained=config.pretrained,
                   result_num=config.n,
                   scale=config.scale)
    if not config.pretrained and not config.restart_training:
        model.apply(weights_init)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    # dummy_input = torch.autograd.Variable(torch.Tensor(1, 3, 600, 800).to(device))
    # writer.add_graph(models=models, input_to_model=dummy_input)
    criterion = PSELoss(Lambda=config.Lambda,
                        ratio=config.OHEM_ratio,
                        reduction='mean')
    # optimizer = torch.optim.SGD(models.parameters(), lr=config.lr, momentum=0.99)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    if config.checkpoint != '' and not config.restart_training:
        start_epoch = load_checkpoint(config.checkpoint, model, logger, device,
                                      optimizer)
        start_epoch += 1
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            config.lr_decay_step,
            gamma=config.lr_gamma,
            last_epoch=start_epoch)
    else:
        start_epoch = config.start_epoch
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         config.lr_decay_step,
                                                         gamma=config.lr_gamma)

    all_step = len(train_loader)
    logger.info('train dataset has {} samples, {} batches in the dataloader'.format(
        train_data.__len__(), all_step))
    epoch = 0
    best_model = {'recall': 0, 'precision': 0, 'f1': 0, 'models': ''}
    try:
        for epoch in range(start_epoch, config.epochs):
            start = time.time()
            train_loss, lr = train_epoch(model, optimizer, scheduler,
                                         train_loader, device, criterion,
                                         epoch, all_step, writer, logger)
            logger.info(
                '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                    epoch, config.epochs, train_loss,
                    time.time() - start, lr))
            # net_save_path = '{}/PSENet_{}_loss{:.6f}.pth'.format(config.output_dir, epoch,
            #                                                                               train_loss)
            # save_checkpoint(net_save_path, models, optimizer, epoch, logger)
            if (0.3 < train_loss < 0.4 and epoch % 4 == 0) or train_loss < 0.3:
                recall, precision, f1 = eval(
                    model, os.path.join(config.output_dir, 'output'), testroot,
                    device)
                logger.info(
                    'test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.
                    format(recall, precision, f1))

                net_save_path = '{}/PSENet_{}_loss{:.6f}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(
                    config.output_dir, epoch, train_loss, recall, precision,
                    f1)
                save_checkpoint(net_save_path, model, optimizer, epoch, logger)
                if f1 > best_model['f1']:
                    best_path = glob.glob(config.output_dir + '/Best_*.pth')
                    for b_path in best_path:
                        if os.path.exists(b_path):
                            os.remove(b_path)

                    best_model['recall'] = recall
                    best_model['precision'] = precision
                    best_model['f1'] = f1
                    best_model['models'] = net_save_path

                    best_save_path = '{}/Best_{}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(
                        config.output_dir, epoch, recall, precision, f1)
                    if os.path.exists(net_save_path):
                        shutil.copyfile(net_save_path, best_save_path)
                    else:
                        save_checkpoint(best_save_path, model, optimizer,
                                        epoch, logger)

                    pse_path = glob.glob(config.output_dir + '/PSENet_*.pth')
                    for p_path in pse_path:
                        if os.path.exists(p_path):
                            os.remove(p_path)

                writer.add_scalar(tag='Test/recall',
                                  scalar_value=recall,
                                  global_step=epoch)
                writer.add_scalar(tag='Test/precision',
                                  scalar_value=precision,
                                  global_step=epoch)
                writer.add_scalar(tag='Test/f1',
                                  scalar_value=f1,
                                  global_step=epoch)
        writer.close()
    except KeyboardInterrupt:
        save_checkpoint('{}/final.pth'.format(config.output_dir), model,
                        optimizer, epoch, logger)
    finally:
        if best_model['models']:
            logger.info(best_model)
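Both PSENet examples resume training through load_checkpoint(config.checkpoint, model, logger, device, optimizer), which returns the stored epoch. A minimal sketch compatible with that call (the checkpoint keys 'state_dict', 'optimizer' and 'epoch' are assumptions):

import torch

def load_checkpoint(checkpoint_path, model, logger, device, optimizer=None):
    # restore model (and optionally optimizer) state and return the stored epoch
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    epoch = ckpt.get('epoch', 0)
    logger.info('loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
    return epoch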
Example #7
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    t_checkpoints = cfg.KD.TEACHER  # note: set this in the student config file
    train_type = cfg.KD.TRAIN_TYPE  # note: set this in the student config file
    train_type = get_train_type(train_type, t_checkpoints)
    logger.info('=> train type is {} '.format(train_type))

    if train_type == 'FPD':
        cfg_name = 'student_' + os.path.basename(args.cfg).split('.')[0]
    else:
        cfg_name = os.path.basename(args.cfg).split('.')[0]
    save_yaml_file(cfg_name, cfg, final_output_dir)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # fpd method, default NORMAL
    if train_type == 'FPD':
        tcfg = cfg.clone()
        tcfg.defrost()
        tcfg.merge_from_file(args.tcfg)
        tcfg.freeze()
        tcfg_name = 'teacher_' + os.path.basename(args.tcfg).split('.')[0]
        save_yaml_file(tcfg_name, tcfg, final_output_dir)
        # teacher model
        tmodel = eval('models.' + tcfg.MODEL.NAME + '.get_pose_net')(
            tcfg, is_train=False)

        load_checkpoint(t_checkpoints,
                        tmodel,
                        strict=True,
                        model_info='teacher_' + tcfg.MODEL.NAME)

        tmodel = torch.nn.DataParallel(tmodel, device_ids=cfg.GPUS).cuda()
        # define kd_pose loss function (criterion) and optimizer
        kd_pose_criterion = JointsMSELoss(
            use_target_weight=tcfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    if cfg.TRAIN.CHECKPOINT:
        load_checkpoint(cfg.TRAIN.CHECKPOINT,
                        model,
                        strict=True,
                        model_info='student_' + cfg.MODEL.NAME)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # you can choose or replace the pose_loss and kd_pose_loss types, including MSE, KL, OHKM loss, etc.
    # define pose loss function (criterion) and optimizer
    pose_criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # evaluate on the validation set before training starts
    if train_type == 'FPD':
        # the teacher model only exists in FPD mode
        validate(cfg, valid_loader, valid_dataset, tmodel, pose_criterion,
                 final_output_dir, tb_log_dir, writer_dict)
    validate(cfg, valid_loader, valid_dataset, model, pose_criterion,
             final_output_dir, tb_log_dir, writer_dict)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # fpd method, default NORMAL
        if train_type == 'FPD':
            # train for one epoch
            fpd_train(cfg, train_loader, model, tmodel, pose_criterion,
                      kd_pose_criterion, optimizer, epoch, final_output_dir,
                      tb_log_dir, writer_dict)
        else:
            # train for one epoch
            train(cfg, train_loader, model, pose_criterion, optimizer, epoch,
                  final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  pose_criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
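In Example #7, fpd_train is expected to combine the plain pose loss with a distillation loss against the teacher's heatmaps. A generic sketch of such a pose-distillation objective (the weighting and the use of plain MSE are assumptions, not the repository's fpd_train):

import torch.nn as nn

def fpd_style_loss(student_heatmaps, teacher_heatmaps, target_heatmaps, alpha=0.5):
    # supervise the student with both the ground-truth heatmaps and the
    # detached teacher heatmaps, blended by alpha
    mse = nn.MSELoss()
    pose_loss = mse(student_heatmaps, target_heatmaps)
    kd_loss = mse(student_heatmaps, teacher_heatmaps.detach())
    return alpha * pose_loss + (1.0 - alpha) * kd_loss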
Example #8
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch CIFAR10 Training with sparse masks')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--lr_decay',
                        default=[150, 250],
                        nargs='+',
                        type=int,
                        help='learning rate decay epochs')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('--batchsize',
                        default=256,
                        type=int,
                        help='batch size')
    parser.add_argument('--epochs',
                        default=350,
                        type=int,
                        help='number of epochs')
    parser.add_argument('--model',
                        type=str,
                        default='resnet32',
                        help='network model name')

    # parser.add_argument('--resnet_n', default=5, type=int, help='number of layers per resnet stage (5 for Resnet-32)')
    parser.add_argument(
        '--budget',
        default=-1,
        type=float,
        help='computational budget (between 0 and 1) (-1 for no sparsity)')
    parser.add_argument('-s',
                        '--save_dir',
                        type=str,
                        default='',
                        help='directory to save model')
    parser.add_argument('-r',
                        '--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        action='store_true',
                        help='evaluation mode')
    parser.add_argument('--plot_ponder',
                        action='store_true',
                        help='plot ponder cost')
    parser.add_argument('--workers',
                        default=8,
                        type=int,
                        help='number of dataloader workers')
    parser.add_argument('--pretrained',
                        action='store_true',
                        help='initialize with pretrained model')
    args = parser.parse_args()
    print('Args:', args)

    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    ## DATA
    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False)

    valset = datasets.CIFAR10(root='../data',
                              train=False,
                              download=True,
                              transform=transform_test)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.batchsize,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=False)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    ## MODEL
    net_module = models.__dict__[args.model]
    model = net_module(sparse=args.budget >= 0,
                       pretrained=args.pretrained).to(device=device)

    ## CRITERION
    class Loss(nn.Module):
        def __init__(self):
            super(Loss, self).__init__()
            self.task_loss = nn.CrossEntropyLoss().to(device=device)
            self.sparsity_loss = dynconv.SparsityCriterion(
                args.budget, args.epochs) if args.budget >= 0 else None

        def forward(self, output, target, meta):
            l = self.task_loss(output, target)
            logger.add('loss_task', l.item())
            if self.sparsity_loss is not None:
                l += 10 * self.sparsity_loss(meta)
            return l

    criterion = Loss()

    ## OPTIMIZER
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    ## CHECKPOINT
    start_epoch = -1
    best_prec1 = 0

    if not args.evaluate and len(args.save_dir) > 0:
        if not os.path.exists(os.path.join(args.save_dir)):
            os.makedirs(os.path.join(args.save_dir))

    if args.resume:
        resume_path = args.resume
        if not os.path.isfile(resume_path):
            resume_path = os.path.join(resume_path, 'checkpoint.pth')
        if os.path.isfile(resume_path):
            print(f"=> loading checkpoint '{resume_path}'")
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(
                f"=> loaded checkpoint '{resume_path}'' (epoch {checkpoint['epoch']}, best prec1 {checkpoint['best_prec1']})"
            )
        else:
            msg = "=> no checkpoint found at '{}'".format(resume_path)
            if args.evaluate:
                raise ValueError(msg)
            else:
                print(msg)

    try:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_decay, last_epoch=start_epoch)
    except:
        print('Warning: Could not reload learning rate scheduler')
    start_epoch += 1

    ## Count number of params
    print("* Number of trainable parameters:", utils.count_parameters(model))

    ## EVALUATION
    if args.evaluate:
        print(f"########## Evaluation ##########")
        prec1 = validate(args, val_loader, model, criterion, start_epoch)
        return

    ## TRAINING
    for epoch in range(start_epoch, args.epochs):
        print(f"########## Epoch {epoch} ##########")

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(args, train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        utils.save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'best_prec1': best_prec1,
            },
            folder=args.save_dir,
            is_best=is_best)

        print(f" * Best prec1: {best_prec1}")
Example #9
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    if args.local_rank != -1:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
    best_acc = 0

    args.print = args.gpu == 0
    # suppress printing if not master
    if (args.multiprocessing_distributed and args.gpu != 0) or\
       (args.local_rank != -1 and args.gpu != 0):

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        if args.local_rank != -1:
            args.rank = args.local_rank
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    ### model ###
    print("=> creating {} model with '{}' backbone".format(
        args.model, args.net))
    if args.model == 'coclr':
        model = CoCLR(args.net,
                      args.moco_dim,
                      args.moco_k,
                      args.moco_m,
                      args.moco_t,
                      topk=args.topk,
                      reverse=args.reverse)
        if args.reverse:
            print('[Warning] using RGB-Mining to help flow')
        else:
            print('[Warning] using Flow-Mining to help RGB')
    else:
        raise NotImplementedError
    args.num_seq = 2
    print('Re-write num_seq to %d' % args.num_seq)

    args.img_path, args.model_path, args.exp_path = set_path(args)

    # print(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
            model_without_ddp = model.module
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
            model_without_ddp = model.module
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    ### optimizer ###
    params = []
    if args.train_what == 'all':
        for name, param in model.named_parameters():
            params.append({'params': param})
    else:
        raise NotImplementedError('train_what invalid')

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    args.iteration = 1

    ### data ###
    transform_train = get_transform('train', args)
    train_loader = get_dataloader(get_data(transform_train, 'train', args),
                                  'train', args)
    transform_train_cuda = transforms.Compose([
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    channel=1)
    ])
    n_data = len(train_loader.dataset)

    print('===================================')

    lr_scheduler = None

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch'] + 1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            state_dict = checkpoint['state_dict']

            try:
                model_without_ddp.load_state_dict(state_dict)
            except:
                print('[WARNING] Non-Equal load for resuming training!')
                neq_load_customized(model_without_ddp,
                                    state_dict,
                                    verbose=True)

            print("=> load resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except:
                print('[WARNING] Not loading optimizer states')
        else:
            print("[Warning] no checkpoint found at '{}', use random init".
                  format(args.resume))

    elif args.pretrain != ['random', 'random']:
        # first path: weights to be trained
        # second path: weights as the oracle, not trained
        if os.path.isfile(
                args.pretrain[1]):  # second network --> load as sampler
            checkpoint = torch.load(args.pretrain[1],
                                    map_location=torch.device('cpu'))
            second_dict = checkpoint['state_dict']
            new_dict = {}
            for k, v in second_dict.items():  # only take the encoder_q
                if 'encoder_q.' in k:
                    k = k.replace('encoder_q.', 'sampler.')
                    new_dict[k] = v
            second_dict = new_dict

            new_dict = {}  # remove queue, queue_ptr
            for k, v in second_dict.items():
                if 'queue' not in k:
                    new_dict[k] = v
            second_dict = new_dict
            print("=> Use Oracle checkpoint '{}' (epoch {})".format(
                args.pretrain[1], checkpoint['epoch']))
        else:
            print("=> NO Oracle checkpoint found at '{}', use random init".
                  format(args.pretrain[1]))
            second_dict = {}

        if os.path.isfile(
                args.pretrain[0]):  # first network --> load both encoder q & k
            checkpoint = torch.load(args.pretrain[0],
                                    map_location=torch.device('cpu'))
            first_dict = checkpoint['state_dict']

            new_dict = {}  # remove queue, queue_ptr
            for k, v in first_dict.items():
                if 'queue' not in k:
                    new_dict[k] = v
            first_dict = new_dict

            # update both q and k with q
            new_dict = {}
            for k, v in first_dict.items():  # only take the encoder_q
                if 'encoder_q.' in k:
                    new_dict[k] = v
                    k = k.replace('encoder_q.', 'encoder_k.')
                    new_dict[k] = v
            first_dict = new_dict

            print("=> Use Training checkpoint '{}' (epoch {})".format(
                args.pretrain[0], checkpoint['epoch']))
        else:
            print("=> NO Training checkpoint found at '{}', use random init".
                  format(args.pretrain[0]))
            first_dict = {}

        state_dict = {**first_dict, **second_dict}
        try:
            del state_dict['queue_label']  # always re-fill the queue
        except:
            pass
        neq_load_customized(model_without_ddp, state_dict, verbose=True)

    else:
        print("=> train from scratch")

    torch.backends.cudnn.benchmark = True

    # tensorboard plot tools
    writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train'))
    args.train_plotter = TB.PlotterThread(writer_train)

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)

        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        _, train_acc = train_one_epoch(train_loader, model, criterion,
                                       optimizer, transform_train_cuda, epoch,
                                       args)
        if (epoch % args.save_freq == 0) or (epoch == args.epochs - 1):
            # save check_point on rank==0 worker
            if (not args.multiprocessing_distributed and args.rank == 0) \
                or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                is_best = train_acc > best_acc
                best_acc = max(train_acc, best_acc)
                state_dict = model_without_ddp.state_dict()
                save_dict = {
                    'epoch': epoch,
                    'state_dict': state_dict,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'iteration': args.iteration
                }
                save_checkpoint(save_dict,
                                is_best,
                                gap=args.save_freq,
                                filename=os.path.join(
                                    args.model_path,
                                    'epoch%d.pth.tar' % epoch),
                                keep_all='k400' in args.dataset)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)
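When the pretrained weights do not match the model exactly, Example #9 falls back to neq_load_customized for a partial, non-strict load. A minimal sketch of that kind of helper (the name matches the call above, the body is an assumption) copies only the tensors whose name and shape agree with the current model:

def neq_load_customized(model, pretrained_dict, verbose=True):
    # copy only the weights whose name and shape match the current model
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    if verbose:
        print('loaded %d / %d tensors from the checkpoint' % (len(matched), len(model_dict)))
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model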
Example #10
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model, dump_input))

    # copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
            best_model = True

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    # Added code.
    # print(train_dataset.classes)  # class names, determined by the sub-folder names
    with open("class.txt", "w") as f1:
        for classname in train_dataset.classes:
            f1.write(classname + "\n")

    # print(train_dataset.class_to_idx)  # classes indexed in order as 0, 1, ...
    with open("classToIndex.txt", "w") as f2:
        for key, value in train_dataset.class_to_idx.items():
            f2.write(str(key) + " " + str(value) + '\n')

    # print(train_dataset.imgs)  # (image path, class index) pairs gathered from all folders

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True)

    valid_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
                transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': config.MODEL.NAME,
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            },
            best_model,
            final_output_dir,
            filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)

    final_pth_file = os.path.join(final_output_dir, 'HRNet.pt')
    print("final_pth_file:", final_pth_file)
    torch.save(model.module, final_pth_file)
    writer_dict['writer'].close()
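
Exemple #10 keeps two artifacts: a weights-only state_dict (final_state.pth.tar) and the fully pickled module (HRNet.pt). A minimal sketch of how each would typically be reloaded; the tiny Sequential model is only a stand-in for the HRNet classifier, and the file names simply mirror the ones used above.

import torch
import torch.nn as nn

# stand-in for the config-driven get_cls_net(...) constructor
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(8, 10))

# 1) weights-only checkpoint: rebuild the architecture first, then load tensors
torch.save(model.state_dict(), 'final_state.pth.tar')
model.load_state_dict(torch.load('final_state.pth.tar', map_location='cpu'))

# 2) fully pickled module: no constructor needed at load time, but the class
#    definitions it references must still be importable
torch.save(model, 'HRNet.pt')
model_full = torch.load('HRNet.pt', map_location='cpu')
model_full.eval()
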
Exemple #11
0
def main():
    global args, best_mIoU
    PID = os.getpid()
    args = parser.parse_args()
    prepare_seed(args.rand_seed)
    device = torch.device("cuda:" + str(args.gpus))

    if args.timestamp == 'none':
        args.timestamp = "{:}".format(
            time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))

    switch_model = args.switch_model
    assert switch_model in ["deeplab50", "deeplab101"]

    # Log outputs
    if args.evaluate:
        args.save_dir = args.save_dir + "/GTA5-%s-evaluate"%switch_model + \
            "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
    else:
        args.save_dir = args.save_dir + \
            "/GTA5_512x512-{model}-LWF.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format(
                    model=switch_model,
                    csg_stages=args.csg_stages,
                    mlp=args.mlp,
                    csg_weight=args.csg,
                    apool=args.apool,
                    augment=args.augment,
                    chunks=args.chunks,
                    csg_k=args.csg_k,
                    lr="%.2E"%args.lr,
                    factor="%.1f"%args.factor,
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    seed=args.rand_seed
                    ) + \
            "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
    logger = prepare_logger(args)

    from config_seg import config as data_setting
    data_setting.batch_size = args.batch_size
    train_loader = get_train_loader(data_setting,
                                    GTA5,
                                    test=False,
                                    augment=args.augment)

    args.stages = [int(stage) for stage in args.csg_stages.split('.')
                   ] if len(args.csg_stages) > 0 else []
    chunks = [int(chunk) for chunk in args.chunks.split('.')
              ] if len(args.chunks) > 0 else []
    assert len(chunks) == 1 or len(chunks) == len(args.stages)
    if len(chunks) < len(args.stages):
        chunks = [chunks[0]] * len(args.stages)

    if switch_model == 'deeplab50':
        layers = [3, 4, 6, 3]
    elif switch_model == 'deeplab101':
        layers = [3, 4, 23, 3]
    model = csg_builder.CSG(deeplab,
                            get_head=None,
                            K=args.csg_k,
                            stages=args.stages,
                            chunks=chunks,
                            task='new-seg',
                            apool=args.apool,
                            mlp=args.mlp,
                            base_encoder_kwargs={
                                'num_seg_classes': args.num_classes,
                                'layers': layers
                            })

    threds = 3
    evaluator = SegEvaluator(
        Cityscapes(data_setting, 'val', None),
        args.num_classes,
        np.array([0.485, 0.456, 0.406]),
        np.array([0.229, 0.224, 0.225]),
        model.encoder_q, [
            1,
        ],
        False,
        devices=args.gpus,
        config=data_setting,
        threds=threds,
        verbose=False,
        save_path=None,
        show_image=False
    )  # just calculate mIoU, no prediction file is generated
    # verbose=False, save_path="./prediction_files", show_image=True, show_prediction=True)  # generate prediction files

    # Setup optimizer
    factor = args.factor
    sgd_in = [
        {
            'params': get_params(model.encoder_q, ["conv1"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["bn1"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["layer1"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["layer2"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["layer3"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["layer4"]),
            'lr': factor * args.lr
        },
        {
            'params': get_params(model.encoder_q, ["fc_new"]),
            'lr': args.lr
        },
    ]
    base_lrs = [group['lr'] for group in sgd_in]
    optimizer = SGD(sgd_in,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint
    if args.resume != 'none':
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            msg = model.load_state_dict(checkpoint['state_dict'])
            print("resume weights: ", msg)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=ImageClassdata> no checkpoint found at '{}'".format(
                args.resume))

    model = model.to(device)

    if args.evaluate:
        mIoU = validate(evaluator, model, -1)
        print(mIoU)
        exit(0)

    # Main training loop
    iter_max = args.epochs * len(train_loader)
    iter_stat = IterNums(iter_max)
    for epoch in range(args.start_epoch, args.epochs):
        print("<< ============== JOB (PID = %d) %s ============== >>" %
              (PID, args.save_dir))
        logger.log("Epoch: %d" % (epoch + 1))
        # train for one epoch
        train(args,
              train_loader,
              model,
              optimizer,
              base_lrs,
              iter_stat,
              epoch,
              logger,
              device,
              adjust_lr=epoch < args.epochs)

        # evaluate on validation set
        torch.cuda.empty_cache()
        mIoU = validate(evaluator, model, epoch)
        logger.writer.add_scalar("mIoU", mIoU, epoch + 1)
        logger.log("mIoU: %f" % mIoU)

        # remember best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)
        save_checkpoint(
            args.save_dir, {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mIoU': best_mIoU,
            }, is_best)

    logging.info('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU))
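
Exemple #11 gives the pretrained backbone blocks a scaled learning rate (factor * args.lr) and keeps the full rate for the newly added head. A minimal sketch of the same per-parameter-group pattern on a generic torchvision ResNet; the backbone/head split by the 'fc' prefix is an assumption made for illustration.

import torchvision
from torch.optim import SGD

model = torchvision.models.resnet50(num_classes=19)
backbone_params = [p for n, p in model.named_parameters() if not n.startswith('fc')]
head_params = [p for n, p in model.named_parameters() if n.startswith('fc')]

optimizer = SGD(
    [
        {'params': backbone_params, 'lr': 0.1 * 1e-2},  # pretrained layers: reduced lr
        {'params': head_params, 'lr': 1e-2},            # fresh head: full lr
    ],
    momentum=0.9,
    weight_decay=1e-4,
)
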
Exemple #12
0
def setup_and_run(args, criterion, device, train_loader, test_loader,
                  val_loader, logging, results):
    global BEST_ACC
    print('\n#### Running REF ####')

    # architecture
    if args.architecture == 'MLP':
        model = models.MLP(args.input_dim, args.hidden_dim,
                           args.output_dim).to(device)
    elif args.architecture == 'LENET300':
        model = models.LeNet300(args.input_dim, args.output_dim).to(device)
    elif args.architecture == 'LENET5':
        model = models.LeNet5(args.input_channels, args.im_size,
                              args.output_dim).to(device)
    elif 'VGG' in args.architecture:
        assert (args.architecture == 'VGG11' or args.architecture == 'VGG13'
                or args.architecture == 'VGG16'
                or args.architecture == 'VGG19')
        model = models.VGG(args.architecture, args.input_channels,
                           args.im_size, args.output_dim).to(device)
    elif args.architecture == 'RESNET18':
        model = models.ResNet18(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == 'RESNET34':
        model = models.ResNet34(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == 'RESNET50':
        model = models.ResNet50(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == 'RESNET101':
        model = models.ResNet101(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    elif args.architecture == 'RESNET152':
        model = models.ResNet152(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    else:
        print('Architecture type "{0}" not recognized, exiting ...'.format(
            args.architecture))
        exit()

    # optimizer
    if args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              nesterov=args.nesterov,
                              weight_decay=args.weight_decay)
    else:
        print('Optimizer type "{0}" not recognized, exiting ...'.format(
            args.optimizer))
        exit()

    # lr-scheduler
    if args.lr_decay == 'STEP':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=1,
                                              gamma=args.lr_scale)
    elif args.lr_decay == 'EXP':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.lr_scale)
    elif args.lr_decay == 'MSTEP':
        x = args.lr_interval.split(',')
        lri = [int(v) for v in x]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=lri,
                                                   gamma=args.lr_scale)
        args.lr_interval = 1  # lr_interval handled in scheduler!
    else:
        print('LR decay type "{0}" not recognized, exiting ...'.format(
            args.lr_decay))
        exit()

    init_weights(model, xavier=True)
    logging.info(model)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("Number of parameters: %d", num_parameters)

    start_epoch = -1
    iters = 0  # total no of iterations, used to do many things!
    # optionally resume from a checkpoint
    if args.eval:
        logging.info('Loading checkpoint file "{0}" for evaluation'.format(
            args.eval))
        if not os.path.isfile(args.eval):
            print('Checkpoint file "{0}" for evaluation not recognized, exiting ...'.format(
                args.eval))
            exit()
        checkpoint = torch.load(args.eval)
        model.load_state_dict(checkpoint['state_dict'])

    elif args.resume:
        checkpoint_file = args.resume
        logging.info('Loading checkpoint file "{0}" to resume'.format(
            args.resume))
        if not os.path.isfile(checkpoint_file):
            print('Checkpoint file "{0}" not recognized, exiting ...'.format(
                checkpoint_file))
            exit()
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        assert (args.architecture == checkpoint['architecture'])
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        BEST_ACC = checkpoint['best_acc1']
        iters = checkpoint['iters']
        logging.debug('best_acc1: {0}, iters: {1}'.format(BEST_ACC, iters))

    if not args.eval:
        logging.info('Training...')
        model.train()
        st = timer()

        for e in range(start_epoch + 1, args.num_epochs):
            for i, (data, target) in enumerate(train_loader):
                l = train_step(model, device, data, target, optimizer,
                               criterion)
                if i % args.log_interval == 0:
                    acc1, acc5 = evaluate(args,
                                          model,
                                          device,
                                          val_loader,
                                          training=True)
                    logging.info(
                        'Epoch: {0},\t Iter: {1},\t Loss: {loss:.5f},\t Val-Acc1: {acc1:.2f} '
                        '(Best: {best:.2f}),\t Val-Acc5: {acc5:.2f}'.format(
                            e, i, loss=l, acc1=acc1, best=BEST_ACC, acc5=acc5))

                if iters % args.lr_interval == 0:
                    lr = args.learning_rate
                    for param_group in optimizer.param_groups:
                        lr = param_group['lr']
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        if lr != param_group['lr']:
                            logging.info('lr: {0}'.format(
                                param_group['lr']))  # print if changed
                iters += 1

            # save checkpoint
            acc1, acc5 = evaluate(args,
                                  model,
                                  device,
                                  val_loader,
                                  training=True)
            results.add(epoch=e,
                        iteration=i,
                        train_loss=l,
                        val_acc1=acc1,
                        best_val_acc1=BEST_ACC)
            util.save_checkpoint(
                {
                    'epoch': e,
                    'architecture': args.architecture,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'best_acc1': BEST_ACC,
                    'iters': iters
                },
                is_best=False,
                path=args.save_dir)
            results.save()

        et = timer()
        logging.info('Elapsed time: {0} seconds'.format(et - st))

        acc1, acc5 = evaluate(args, model, device, val_loader, training=True)
        logging.info(
            'End of training, Val-Acc: {acc1:.2f} (Best: {best:.2f}), Val-Acc5: {acc5:.2f}'
            .format(acc1=acc1, best=BEST_ACC, acc5=acc5))
        # load saved model
        saved_model = torch.load(args.save_name)
        model.load_state_dict(saved_model['state_dict'])
    # end of training

    # eval-set
    if args.eval_set != 'TRAIN' and args.eval_set != 'TEST':
        print('Evaluation set "{0}" not recognized ...'.format(args.eval_set))

    logging.info('Evaluating REF on the {0} set...'.format(args.eval_set))
    st = timer()
    if args.eval_set == 'TRAIN':
        acc1, acc5 = evaluate(args, model, device, train_loader)
    else:
        acc1, acc5 = evaluate(args, model, device, test_loader)
    et = timer()
    logging.info('Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%'.format(
        acc1=acc1, acc5=acc5))
    logging.info('Elapsed time: {0} seconds'.format(et - st))
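
Exemple #12 restores the model, optimizer and scheduler states together, so a resumed run continues with exactly the same learning-rate schedule it left off with. A minimal sketch of that save/resume round trip; the tiny linear model, milestones and file name are placeholders.

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

# save
torch.save({
    'epoch': 5,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, 'checkpoint.pth.tar')

# resume
ckpt = torch.load('checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
start_epoch = ckpt['epoch'] + 1
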
Exemple #13
0
        }

        # save state
        log_epoch(
            logger,
            epoch,
            train_loss,
            val_loss,
            opti.param_groups[0]["lr"],
            batch_train,
            batch_val,
            data_train,
            data_val,
            recall,
        )
        save_checkpoint(state, is_best, args.name, epoch)

        # Optimizing the text pipeline after one epoch
        if epoch == 1:
            print("Start training text pooler")
            if args.bert > 0:
                opti.add_param_group({
                    "params":
                    filter(
                        lambda p: p.requires_grad,
                        join_emb.module.cap_emb.model.pooler.parameters(),
                    ),
                    "lr":
                    0.00001
                })
Exemple #14
0
def train_and_evaluate(model,
                       train_dataloader,
                       val_dataloader,
                       optimizer,
                       loss_fn,
                       metrics,
                       params,
                       model_dir,
                       restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)
        #
        val_acc = val_metrics['accuracy']
        is_best = val_acc >= best_val_acc

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir,
                                          "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir,
                                      "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
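
Exemple #14 writes the latest and the best validation metrics to JSON files next to the checkpoints. The utils module it relies on is not shown here; a minimal sketch of what its save_dict_to_json helper is assumed to do:

import json

def save_dict_to_json(d, json_path):
    # cast to float so numpy scalars and 0-d tensors serialise cleanly
    with open(json_path, 'w') as f:
        json.dump({k: float(v) for k, v in d.items()}, f, indent=4)
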
Exemple #15
0
def main():
    global best_iou
    global best_dice
    # model
    model = smp.Unet(encoder_name=configs.encoder,
                     encoder_weights=configs.encoder_weights,
                     classes=configs.num_classes,
                     activation=configs.activation)
    if len(configs.gpu_id) > 1:
        model = nn.DataParallel(model)
    model.cuda()
    # get files
    filenames = glob(configs.dataset + "masks/*")
    filenames = [os.path.basename(i) for i in filenames]
    # random split dataset into train and val
    train_files, val_files = train_test_split(filenames, test_size=0.2)
    # define different aug
    if configs.use_strong_aug:
        transform_train = stong_aug()
    else:
        transform_train = get_training_augmentation()
    transform_valid = get_valid_augmentation()
    # make data loader for train and val
    train_dataset = SegDataset(train_files,
                               phase="train",
                               transforms=transform_train)
    valid_dataset = SegDataset(val_files,
                               phase="valid",
                               transforms=transform_valid)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=configs.bs,
                                               shuffle=True,
                                               num_workers=configs.workers)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=configs.bs,
                                               shuffle=False,
                                               num_workers=configs.workers)
    optimizer = get_optimizer(model)
    loss_func = get_loss_func(configs.loss_func)
    criterion = loss_func().cuda()
    # tensorboardX writer
    writer = SummaryWriter(configs.log_dir)
    # set lr scheduler method
    if configs.lr_scheduler == "step":
        scheduler_default = torch.optim.lr_scheduler.StepLR(optimizer,
                                                            step_size=10,
                                                            gamma=0.1)
    elif configs.lr_scheduler == "on_loss":
        scheduler_default = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5, verbose=False)
    elif configs.lr_scheduler == "on_iou":
        scheduler_default = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.2, patience=5, verbose=False)
    elif configs.lr_scheduler == "on_dice":
        scheduler_default = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.2, patience=5, verbose=False)
    elif configs.lr_scheduler == "cosine":
        scheduler_default = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, configs.epochs - configs.warmup_epo)
    else:
        scheduler_default = torch.optim.lr_scheduler.StepLR(optimizer,
                                                            step_size=6,
                                                            gamma=0.1)
    # scheduler with warmup
    if configs.warmup:
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=configs.warmup_factor,
                                           total_epoch=configs.warmup_epo,
                                           after_scheduler=scheduler_default)
    else:
        scheduler = scheduler_default
    for epoch in range(configs.epochs):
        print('\nEpoch: [%d | %d] LR: %.8f' %
              (epoch + 1, configs.epochs, optimizer.param_groups[0]['lr']))
        train_loss, train_dice, train_iou = train(train_loader, model,
                                                  criterion, optimizer, epoch,
                                                  writer)
        valid_loss, valid_dice, valid_iou = eval(valid_loader, model,
                                                 criterion, epoch, writer)
        if configs.lr_scheduler == "step" or configs.lr_scheduler == "cosine" or configs.warmup:
            scheduler.step(epoch)
        elif configs.lr_scheduler == "on_iou":
            scheduler.step(valid_iou)
        elif configs.lr_scheduler == "on_dice":
            scheduler.step(valid_dice)
        elif configs.lr_scheduler == "on_loss":
            scheduler.step(valid_loss)
        # save model
        is_best_iou = valid_iou > best_iou
        is_best_dice = valid_dice > best_dice
        best_iou = max(valid_iou, best_iou)
        best_dice = max(valid_dice, best_dice)
        print("Best {}: {} ,Best Dice: {}".format(configs.metric, best_iou,
                                                  best_dice))
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, is_best_iou, is_best_dice)
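
Exemple #15 chooses what it passes to scheduler.step() based on what the ReduceLROnPlateau instance monitors: the validation loss for mode='min', IoU or Dice for mode='max'. A minimal sketch of that contract; the model and metric values are placeholders.

import torch.nn as nn
import torch.optim as optim

model = nn.Conv2d(3, 1, 3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                 factor=0.2, patience=5)

for epoch, valid_iou in enumerate([0.41, 0.43, 0.42]):
    # ... train and validate ...
    scheduler.step(valid_iou)  # pass the monitored metric, not the epoch index
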
Exemple #16
0
def main():
    torch.manual_seed(42)
    args = parse_args()
    start_time = str(datetime.datetime.now())
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
    end_time = str(datetime.datetime.now())
    file = open("train_time.txt", "w")
    file.write(start_time)
    file.write(end_time)
    file.close()
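
The example above only stores the raw start and end timestamps. A minimal sketch that also records the elapsed wall-clock time, assuming the same train_time.txt file name:

import datetime

start_time = datetime.datetime.now()
# ... training loop ...
end_time = datetime.datetime.now()

with open('train_time.txt', 'w') as f:
    f.write('start:   {}\n'.format(start_time))
    f.write('end:     {}\n'.format(end_time))
    f.write('elapsed: {}\n'.format(end_time - start_time))
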
Exemple #17
0
def main(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda')
    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size

    if args.dataset == 'ucf101': args.num_class = 101
    elif args.dataset == 'hmdb51': args.num_class = 51
    elif args.dataset == 'CATER_actions_present': args.num_class = 14
    elif args.dataset == 'CATER_actions_order_uniq': args.num_class = 301

    ### classifier model ###
    if args.model == 'lc':
        model = LC(sample_size=args.img_dim,
                   num_seq=args.num_seq,
                   seq_len=args.seq_len,
                   network=args.net,
                   num_class=args.num_class,
                   dropout=args.dropout,
                   train_what=args.train_what)
    elif args.model == 'timecycle':
        model = CycleTime(sample_size=args.img_dim,
                          num_seq=args.num_seq,
                          seq_len=args.seq_len,
                          num_class=args.num_class,
                          dropout=args.dropout,
                          train_what=args.train_what)
    else:
        raise ValueError('wrong model!')

    model.to(device)
    model = nn.DataParallel(model)
    model_without_dp = model.module
    if args.dataset.split('_')[0] == 'CATER':
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    params = None
    if args.train_what == 'ft':
        print('=> finetune backbone with smaller lr')
        params = []
        for name, param in model.module.named_parameters():
            if ('resnet' in name) or ('rnn' in name):
                params.append({'params': param, 'lr': args.lr / 10})
            else:
                params.append({'params': param})
    elif args.train_what == 'last':
        print('=> train only last layer')
        params = []
        for name, param in model.named_parameters():
            if ('bone' in name) or ('agg' in name) or ('mb' in name) or (
                    'network_pred' in name):
                param.requires_grad = False
            else:
                params.append({'params': param})
    else:
        pass  # train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    if params is None: params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    ### scheduler ###
    if args.dataset == 'hmdb51':
        step = args.schedule
        if step == []: step = [150, 250]
        lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(
            ep, gamma=0.1, step=step, repeat=1)
    elif args.dataset == 'ucf101':
        step = args.schedule
        if step == []: step = [300, 400]
        lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(
            ep, gamma=0.1, step=step, repeat=1)
    elif args.dataset.split('_')[0] == 'CATER':
        step = args.schedule
        if step == []: step = [150, 250]
        lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(
            ep, gamma=0.1, step=step, repeat=1)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    print('=> Using scheduler at {} epochs'.format(step))

    args.old_lr = None
    best_acc = 0
    args.iteration = 1

    ### if in test mode ###
    if args.test:
        if os.path.isfile(args.test):
            print("=> loading test checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test,
                                    map_location=torch.device('cpu'))
            try:
                model_without_dp.load_state_dict(checkpoint['state_dict'])
            except:
                print(
                    '=> [Warning]: weight structure is not equal to test model; Load anyway =='
                )
                model_without_dp = neq_load_customized(
                    model_without_dp, checkpoint['state_dict'])
            epoch = checkpoint['epoch']
            print("=> loaded testing checkpoint '{}' (epoch {})".format(
                args.test, checkpoint['epoch']))
        elif args.test == 'random':
            epoch = 0
            print("=> loaded random weights")
        else:
            print("=> no checkpoint found at '{}'".format(args.test))
            sys.exit(0)

        args.logger = Logger(path=os.path.dirname(args.test))
        _, test_dataset = get_data(None, 'test')
        test_loss, test_acc = test(test_dataset, model, criterion, device,
                                   epoch, args)
        sys.exit()

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model_without_dp.load_state_dict(checkpoint['state_dict'])
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except:
                print('[WARNING] Not loading optimizer states')
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            sys.exit(0)

    if (not args.resume) and args.pretrain:
        if args.pretrain == 'random':
            print('=> using random weights')
        elif os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model_without_dp = neq_load_customized(model_without_dp,
                                                   checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))
            sys.exit(0)

    ### data ###
    transform = transforms.Compose([
        A.RandomSizedCrop(consistent=True, size=224, p=1.0),
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.ColorJitter(brightness=0.5,
                      contrast=0.5,
                      saturation=0.5,
                      hue=0.25,
                      p=0.3,
                      consistent=True),
        A.ToTensor(),
        A.Normalize()
    ])
    val_transform = transforms.Compose([
        A.RandomSizedCrop(consistent=True, size=224, p=0.3),
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.ColorJitter(brightness=0.2,
                      contrast=0.2,
                      saturation=0.2,
                      hue=0.1,
                      p=0.3,
                      consistent=True),
        A.ToTensor(),
        A.Normalize()
    ])

    train_loader, _ = get_data(transform, 'train')
    val_loader, _ = get_data(val_transform, 'val')

    # setup tools
    args.img_path, args.model_path = set_path(args)
    args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    args.writer_train = SummaryWriter(
        logdir=os.path.join(args.img_path, 'train'))
    torch.backends.cudnn.benchmark = True

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc = train_one_epoch(train_loader, model, criterion,
                                                optimizer, device, epoch, args)
        val_loss, val_acc = validate(val_loader, model, criterion, device,
                                     epoch, args)
        lr_scheduler.step(epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_dict = {
            'epoch': epoch,
            'backbone': args.net,
            'state_dict': model_without_dp.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': args.iteration
        }
        save_checkpoint(save_dict,
                        is_best,
                        filename=os.path.join(args.model_path,
                                              'epoch%s.pth.tar' % str(epoch)),
                        keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)
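
Exemple #17 falls back to neq_load_customized whenever the checkpoint keys do not match the model exactly. Its implementation is not shown in this snippet; a minimal sketch of such a "load what matches" helper, under the assumption that it simply skips mismatched tensors:

import torch

def load_matching_weights(model, state_dict, verbose=True):
    model_dict = model.state_dict()
    # keep only entries whose name and shape both match the target model
    matched = {k: v for k, v in state_dict.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    if verbose:
        print('loaded {}/{} tensors'.format(len(matched), len(model_dict)))
    return model
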
Exemple #18
0
                                       step)
                board_valid.add_scalar('D/loss_D_real',
                                       loss_D_real_total.item() / n_valid_loop,
                                       step)
                board_valid.add_scalar('D/loss_D_fake',
                                       loss_D_fake_total.item() / n_valid_loop,
                                       step)

            step += 1
            n_print -= 1

        #====================================================
        # Save the model
        #====================================================
        if (epoch % args.n_save_epoches == 0):
            save_checkpoint(
                model_G, device,
                os.path.join(args.save_checkpoints_dir, args.exper_name,
                             'model_G_ep%03d.pth' % (epoch)))
            save_checkpoint(
                model_G, device,
                os.path.join(args.save_checkpoints_dir, args.exper_name,
                             'model_G_final.pth'))
            print("saved checkpoints")

    print("Finished Training Loop.")
    save_checkpoint(
        model_G, device,
        os.path.join(args.save_checkpoints_dir, args.exper_name,
                     'model_G_final.pth'))
Exemple #19
0
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_p, model_d = eval('models.' + cfg.MODEL.NAME +
                            '.get_adaptive_pose_net')(cfg, is_train=True)

    if cfg.TRAIN.CHECKPOINT:
        logger.info('=> loading model from {}'.format(cfg.TRAIN.CHECKPOINT))
        model_p.load_state_dict(torch.load(cfg.TRAIN.CHECKPOINT))

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'pre_train_global_steps': 0,
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model_p, (dump_input, ), verbose=False)

    logger.info(get_model_summary(model_p, dump_input))

    model_p = torch.nn.DataParallel(model_p, device_ids=cfg.GPUS).cuda()
    model_d = torch.nn.DataParallel(model_d, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer for pose_net
    criterion_p = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer_p = get_optimizer(cfg, model_p)

    # define loss function (criterion) and optimizer for domain
    criterion_d = torch.nn.BCEWithLogitsLoss().cuda()
    optimizer_d = get_optimizer(cfg, model_d)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_pre_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_PRE_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_pre_loader = torch.utils.data.DataLoader(
        train_pre_dataset,
        batch_size=cfg.TRAIN.PRE_BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    syn_labels = train_dataset._load_syrip_syn_annotations()
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=BalancedBatchSampler(train_dataset, syn_labels),
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model_p.load_state_dict(checkpoint['state_dict'])

        optimizer_p.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler_p = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_p,
        cfg.TRAIN.LR_STEP,
        cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch)

    lr_scheduler_d = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_d, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR)

    epoch_D = cfg.TRAIN.PRE_EPOCH
    losses_D_list = []
    acces_D_list = []
    acc_num_total = 0
    num = 0
    losses_d = AverageMeter()

    # Pretrained Stage
    print('Pretrained Stage:')
    print('Start to train Domain Classifier-------')
    for epoch_d in range(epoch_D):  # epoch
        model_d.train()
        model_p.train()

        for i, (input, target, target_weight,
                meta) in enumerate(train_pre_loader):  # iteration
            # compute output for pose_net
            feature_outputs, outputs = model_p(input)
            #print(feature_outputs.size())
            # compute for domain classifier
            domain_logits = model_d(feature_outputs.detach())
            domain_label = (meta['synthetic'].unsqueeze(-1) *
                            1.0).cuda(non_blocking=True)
            # print(domain_label)

            loss_d = criterion_d(domain_logits, domain_label)
            loss_d.backward(retain_graph=True)
            optimizer_d.step()

            # compute accuracy of classifier
            acc_num = 0
            for j in range(len(domain_label)):
                if (domain_logits[j] > 0 and domain_label[j] == 1.0) or (
                        domain_logits[j] < 0 and domain_label[j] == 0.0):
                    acc_num += 1
                    acc_num_total += 1
                num += 1
            acc_d = acc_num * 1.0 / input.size(0)
            acces_D_list.append(acc_d)

            optimizer_d.zero_grad()
            losses_d.update(loss_d.item(), input.size(0))

            if i % cfg.PRINT_FREQ == 0:
                msg = 'Epoch: [{0}][{1}/{2}]\t' \
                      'Accuracy_d: {3} ({4})\t' \
                      'Loss_d: {loss_d.val:.5f} ({loss_d.avg:.5f})'.format(
                          epoch_d, i, len(train_pre_loader), acc_d, acc_num_total * 1.0 / num, loss_d = losses_d)
                logger.info(msg)

                writer = writer_dict['writer']
                pre_global_steps = writer_dict['pre_train_global_steps']
                writer.add_scalar('pre_train_loss_D', losses_d.val,
                                  pre_global_steps)
                writer.add_scalar('pre_train_acc_D', acc_d, pre_global_steps)
                writer_dict['pre_train_global_steps'] = pre_global_steps + 1

            losses_D_list.append(losses_d.val)

    print('Training Stage (Step I and II):')
    losses_P_list = []
    acces_P_list = []
    losses_p = AverageMeter()
    acces_p = AverageMeter()
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler_p.step()

        # train for one epoch
        losses_P_list, losses_D_list, acces_P_list, acces_D_list = train_adaptive(
            cfg, train_loader, model_p, model_d, criterion_p, criterion_d,
            optimizer_p, optimizer_d, epoch, final_output_dir, tb_log_dir,
            writer_dict, losses_P_list, losses_D_list, acces_P_list,
            acces_D_list, acc_num_total, num, losses_p, acces_p, losses_d)

        # evaluate on validation set
        perf_indicator = validate_adaptive(cfg, valid_loader, valid_dataset,
                                           model_p, criterion_p,
                                           final_output_dir, tb_log_dir,
                                           writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model_p.state_dict(),
                'best_state_dict': model_p.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer_p.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model_p.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

    np.save('./losses_D.npy', np.array(losses_D_list))  # Adversarial-D
    np.save('./losses_P.npy', np.array(losses_P_list))  # P
    np.save('./acces_P.npy', np.array(acces_P_list))  # P
    np.save('./acces_D.npy', np.array(acces_D_list))  # D
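
Exemple #19 accumulates the domain classifier's accuracy element by element, comparing the sign of each logit against the binary label (BCEWithLogitsLoss works on raw logits, so logit > 0 corresponds to a predicted probability above 0.5). A minimal vectorised sketch of the same computation; the function name is illustrative.

import torch

def domain_accuracy(domain_logits, domain_label):
    # logit > 0  <=>  sigmoid(logit) > 0.5  <=>  predicted "synthetic"
    preds = (domain_logits > 0).float()
    return (preds == domain_label).float().mean().item()
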
Exemple #20
0
def main_worker(gpu, ngpus_per_node, args, final_output_dir, tb_log_dir):
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if cfg.MULTIPROCESSING_DISTRIBUTED:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        print('Init process group: dist_url: {}, world_size: {}, rank: {}'.
              format(args.dist_url, args.world_size, args.rank))
        dist.init_process_group(backend=cfg.DIST_BACKEND,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    update_config(cfg, args)

    # setup logger
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    if not cfg.MULTIPROCESSING_DISTRIBUTED or (cfg.MULTIPROCESSING_DISTRIBUTED
                                               and args.rank % ngpus_per_node
                                               == 0):
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
            final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if not cfg.MULTIPROCESSING_DISTRIBUTED or (cfg.MULTIPROCESSING_DISTRIBUTED
                                               and args.rank % ngpus_per_node
                                               == 0):
        dump_input = torch.rand(
            (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
        # writer_dict['writer'].add_graph(model, (dump_input, ))
        # logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    if args.distributed:
        if cfg.MODEL.SYNC_BN:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    loss_factory = MultiLossFactory(cfg).cuda()

    # Data loading code
    train_loader = make_dataloader(cfg,
                                   is_train=True,
                                   distributed=args.distributed)
    logger.info(train_loader.dataset)

    best_perf = -1
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)

    if cfg.FP16.ENABLED:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=cfg.FP16.STATIC_LOSS_SCALE,
            dynamic_loss_scale=cfg.FP16.DYNAMIC_LOSS_SCALE)

    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    if cfg.FP16.ENABLED:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer.optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # train one epoch
        do_train(cfg,
                 model,
                 train_loader,
                 loss_factory,
                 optimizer,
                 epoch,
                 final_output_dir,
                 tb_log_dir,
                 writer_dict,
                 fp16=cfg.FP16.ENABLED)

        # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
        lr_scheduler.step()

        perf_indicator = epoch
        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if not cfg.MULTIPROCESSING_DISTRIBUTED or (
                cfg.MULTIPROCESSING_DISTRIBUTED and args.rank == 0):
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': cfg.MODEL.NAME,
                    'state_dict': model.state_dict(),
                    'best_state_dict': model.module.state_dict(),
                    'perf': perf_indicator,
                    'optimizer': optimizer.state_dict(),
                }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state{}.pth.tar'.format(gpu))

    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
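
The resume logic above hinges on storing the epoch, model weights, and optimizer state under fixed keys and reloading them when the checkpoint file exists. A minimal sketch of that pattern follows; the file name and helper names are illustrative, not the exact functions used in the example.

import os
import torch

def write_checkpoint(state, output_dir, filename='checkpoint.pth.tar'):
    # Persist the full training state so a later run can resume from it.
    torch.save(state, os.path.join(output_dir, filename))

def try_resume(model, optimizer, output_dir, filename='checkpoint.pth.tar'):
    # Return the epoch to start from; 0 when no checkpoint is found.
    path = os.path.join(output_dir, filename)
    if not os.path.exists(path):
        return 0
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']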
Exemple #21
0
def main():
    args.lr = 1e-4
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.gpu_id = "cuda:4"
    args.start_epoch = 0
    args.epochs = 1000
    args.bn = False
    args.workers = 2
    args.seed = time.time()
    args.print_freq = 400
    args.dataset = 'A'
    args.dataset_id = 1

    train_image_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/train_image'
    train_gt_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/train_gt_ADAPTIVE_0'
    val_image_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/val_image'
    val_gt_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/val_gt_ADAPTIVE_0'
    test_image_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/test_image'
    test_gt_dir = '/home/rainkeeper/Projects/Datasets/shanghaiTech/processed_CSRNet_uncrop_data_gpu' + str(
        args.dataset_id) + '/part_' + args.dataset + '/test_gt_ADAPTIVE_0'

    args.device = torch.device(
        args.gpu_id if torch.cuda.is_available() else "cpu")
    torch.cuda.manual_seed(args.seed)

    model = net()
    model.to(args.device)

    criterion = nn.MSELoss(reduction='sum').to(args.device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr)

    # if args.pre:
    if False:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            args.best_mae = checkpoint['best_mae']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        train_mae, train_mse, train_gt_sum, train_predict_sum = train(
            train_image_dir, train_gt_dir, model, criterion, optimizer, epoch)
        val_mae, val_mse, val_gt_sum, val_predict_sum = validate(
            val_image_dir, val_gt_dir, model)
        test_mae, test_mse, test_gt_sum, test_predict_sum = sssss(
            test_image_dir, test_gt_dir, model)

        is_best = (test_mae <= args.best_mae)
        if is_best:
            args.best_mae = test_mae
            args.best_mse = test_mse
            args.best_epoch = epoch

        print(
            'current train mae: %.6f, current train mse: %.6f, current_train_gt_sum: %.6f, current_train_predict_sum:%.6f'
            % (train_mae, train_mse, train_gt_sum, train_predict_sum))
        print(
            'current val mae: %.6f, current val mse: %.6f, current_gt_sum: %.6f, current_predict_sum:%.6f'
            % (val_mae, val_mse, val_gt_sum, val_predict_sum))
        print(
            'current test mae: %.6f, current test mse: %.6f, current_test_gt_sum: %.6f, current_test_predict_sum:%.6f'
            % (test_mae, test_mse, test_gt_sum, test_predict_sum))
        print('best test mae: %.6f, best test mse: %.6f' %
              (args.best_mae, args.best_mse))
        print('best epoch:%d' % args.best_epoch)
        print('\n')

        print(
            'current train mae: %.6f, current train mse: %.6f, current_train_gt_sum: %.6f, current_train_predict_sum:%.6f'
            % (train_mae, train_mse, train_gt_sum, train_predict_sum),
            file=terminal_file)
        print(
            'current val mae: %.6f, current val mse: %.6f, current_gt_sum: %.6f, current_predict_sum:%.6f'
            % (val_mae, val_mse, val_gt_sum, val_predict_sum),
            file=terminal_file)
        print(
            'current test mae: %.6f, current test mse: %.6f, current_test_gt_sum: %.6f, current_test_predict_sum:%.6f'
            % (test_mae, test_mse, test_gt_sum, test_predict_sum),
            file=terminal_file)
        print('best test mae: %.6f, best test mse: %.6f' %
              (args.best_mae, args.best_mse),
              file=terminal_file)
        print('best epoch:%d' % args.best_epoch, file=terminal_file)
        print('\n', file=terminal_file)

        if is_best:
            test_mae_2f = round(test_mae, 2)
            test_mse_2f = round(test_mse, 2)
            val_mae_2f = round(val_mae, 2)
            val_mse_2f = round(val_mse, 2)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.pre,
                    'state_dict': model.state_dict(),
                    'current_mae': test_mae,
                    'best_mae': args.best_mae,
                    'optimizer': optimizer.state_dict()
                }, checkpoint_save_dir, args.dataset, epoch, test_mae_2f,
                test_mse_2f, val_mae_2f, val_mse_2f)
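
The best-MAE bookkeeping in this example boils down to: evaluate, compare against the best value so far, and only write a checkpoint on improvement. A minimal, self-contained sketch of that comparison (the helper and path names are placeholders, not the example's own functions):

import math
import torch

def update_best_mae(current_mae, best_mae, model, path='best_model.pth'):
    # Lower MAE is better; save weights only when the metric improves.
    if current_mae <= best_mae:
        torch.save(model.state_dict(), path)
        return current_mae
    return best_mae

best_mae = math.inf  # start with "no best yet" before the first epoch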
Exemple #22
0
def train_val(model, args):

    train_dir = args.train_dir
    val_dir = args.val_dir

    config = Config(args.config)
    cudnn.benchmark = True

    # The LSPET dataset contains 10000 images; the LSP dataset contains 2000 images.

    # train
    train_loader = torch.utils.data.DataLoader(lsp_lspet_data.LSP_Data(
        'lspet', train_dir, 8,
        Mytransforms.Compose([
            Mytransforms.RandomResized(),
            Mytransforms.RandomRotate(40),
            Mytransforms.RandomCrop(368),
            Mytransforms.RandomHorizontalFlip(),
        ])),
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True)

    # val
    if args.val_dir is not None and config.test_interval != 0:
        # val
        val_loader = torch.utils.data.DataLoader(lsp_lspet_data.LSP_Data(
            'lsp', val_dir, 8,
            Mytransforms.Compose([
                Mytransforms.TestResized(368),
            ])),
                                                 batch_size=config.batch_size,
                                                 shuffle=True,
                                                 num_workers=config.workers,
                                                 pin_memory=True)

    criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params,
                                config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()
    iters = config.start_iters
    best_model = config.best_model

    heat_weight = 46 * 46 * 15 / 1.0

    while iters < config.max_iter:
        # each pass over train_loader increments i by 1 and fetches a batch of 16 inputs
        for i, (input, heatmap, centermap,
                img_path) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer,
                iters,
                config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            heatmap = heatmap.cuda(non_blocking=True)
            #print(heatmap)
            #sys.exit(1)
            centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            losses.update(loss.data[0], input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5,
                                     loss6]):
                losses_list[cnt].update(l.data[0], input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            #print(i,'\n')
            if iters % config.display == 0:
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters,
                        config.display,
                        learning_rate,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))

                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))
                #############    image write  ##################
                for cnt in range(config.batch_size):
                    kpts = get_kpts(heat6[cnt], img_h=368.0, img_w=368.0)
                    draw_paint(img_path[cnt], kpts, i, cnt)
                #######################################################
                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

            save_checkpoint({
                'iter': iters,
                'state_dict': model.state_dict(),
            }, 0, args.model_name)

            # val
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:

                model.eval()
                for j, (input, heatmap, centermap) in enumerate(val_loader):
                    heatmap = heatmap.cuda(non_blocking=True)
                    centermap = centermap.cuda(non_blocking=True)

                    input_var = torch.autograd.Variable(input)
                    heatmap_var = torch.autograd.Variable(heatmap)
                    centermap_var = torch.autograd.Variable(centermap)

                    heat1, heat2, heat3, heat4, heat5, heat6 = model(
                        input_var, centermap_var)

                    loss1 = criterion(heat1, heatmap_var) * heat_weight
                    loss2 = criterion(heat2, heatmap_var) * heat_weight
                    loss3 = criterion(heat3, heatmap_var) * heat_weight
                    loss4 = criterion(heat4, heatmap_var) * heat_weight
                    loss5 = criterion(heat5, heatmap_var) * heat_weight
                    loss6 = criterion(heat6, heatmap_var) * heat_weight

                    loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
                    losses.update(loss.data[0], input.size(0))
                    for cnt, l in enumerate(
                        [loss1, loss2, loss3, loss4, loss5, loss6]):
                        losses_list[cnt].update(l.data[0], input.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()
                    is_best = losses.avg < best_model
                    best_model = min(best_model, losses.avg)
                    save_checkpoint(
                        {
                            'iter': iters,
                            'state_dict': model.state_dict(),
                        }, is_best, args.model_name)

                    if j % config.display == 0:
                        print(
                            'Test Iteration: {0}\t'
                            'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                            'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:3f})\n'
                            'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                            format(j,
                                   config.display,
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses))
                        for cnt in range(0, 6):
                            print(
                                'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                .format(cnt + 1, loss1=losses_list[cnt]))

                        print(
                            time.strftime(
                                '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                time.localtime()))
                        batch_time.reset()
                        losses.reset()
                        for cnt in range(6):
                            losses_list[cnt].reset()

                model.train()
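
The loss and timing statistics in this example (losses, losses_list, batch_time, data_time) come from an AverageMeter-style helper imported elsewhere. A commonly used minimal version, given here as an assumption about its interface rather than the exact class the example imports:

class AverageMeter:
    """Track the latest value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count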
Exemple #23
0
def train(epochs, model, optimizer, train_loader, save_path, device=None):
    ''' Train MDNRNN.

    Parameters
    ----------
    epochs : int
        The number of epochs for which to train the model.

    model : torch.nn.Module
        The model to train.

    optimizer : torch.optim.Optimizer
        The optimizer to use in training the model.

    train_loader : torch.utils.data.DataLoader
        The data loader to use in training.

    save_path : Union[str, pathlib.Path]
       The path to which to save the model.

    device : Optional[str]
        The device to use for training, or `None` to auto-detect whether CUDA can be used.
    '''
    device = device if device else 'cuda' if torch.cuda.is_available(
    ) else 'cpu'
    model = model.to(device).train()
    batch_count = len(train_loader)
    train_iters = batch_count * epochs

    for epoch in range(epochs):
        hidden = model.init_hidden(batch_size, device)

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)  # shape (999, 1, 1, 35)
            targets = targets.to(device)  # shape (999, 1, 1, 32)

            optimizer.zero_grad()

            hidden = model.init_hidden(inputs.size(1), device)

            y, (h, c) = model.lstm(inputs.view(-1, 1, 35), hidden)
            h = h.detach()
            c = c.detach()
            log_pi, mu, sigma = model.get_mixture(y)
            l = mdnrnn_loss(log_pi, mu, sigma, targets.view(-1, 1, 32))
            l.backward()
            optimizer.step()

            # data logging parameters
            percent_comp = ((batch_idx + 1) +
                            (epoch * batch_count)) / train_iters
            time = datetime.datetime.now().time()
            date = datetime.datetime.now().date()
            df = pd.DataFrame(
                [[batch_idx, percent_comp,
                  l.item(), time, date]],
                columns=['batch_idx', '%_comp', 'batch_loss', 'time', 'date'])

            # printing progress
            if batch_idx % print_interval == 0:
                print('percent complete: ' + str(percent_comp) + ', batch: ' +
                      str(batch_idx) + ', loss: ' + str(l))

            # logging progress
            if batch_idx % save_interval == 0:
                # Save training log
                with open(save_path + '.txt', 'a') as f:
                    record = df.to_json(orient='records')
                    f.write(record)
                    f.write(os.linesep)

                # Save model and optimizer state dicts
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, save_path, True)
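
Progress logging in this example writes one pandas record per save interval as a line of JSON appended to a text file. A small sketch of that pattern (the helper name and column subset are illustrative):

import os
import pandas as pd

def append_progress(save_path, batch_idx, percent_comp, batch_loss):
    # One DataFrame row per call, serialized as a JSON record on its own line.
    df = pd.DataFrame([[batch_idx, percent_comp, batch_loss]],
                      columns=['batch_idx', '%_comp', 'batch_loss'])
    with open(save_path + '.txt', 'a') as f:
        f.write(df.to_json(orient='records'))
        f.write(os.linesep)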
Exemple #24
0
def binary_learning(train_loader, network, criterion, test_loader, optimizer,
                    start_epoch, lr_scheduler):
    vis = visdom.Visdom()
    r_loss = []
    r_average_f1 = []
    iterations = []
    epochs = []
    total_iteration = 0

    options = dict(legend=['loss'])
    loss_plot = vis.line(Y=np.zeros(1), X=np.zeros(1), opts=options)
    options = dict(legend=['average_f1'])
    average_f1_plot = vis.line(Y=np.zeros(1), X=np.zeros(1), opts=options)

    for epoch in range(start_epoch,
                       params.number_of_epochs_for_metric_learning):

        print('current_learning_rate =', optimizer.param_groups[0]['lr'], ' ',
              datetime.datetime.now())

        i = 0
        for data in train_loader:
            i = i + 1
            inputs, labels = data
            # print('inputs ', inputs) # batch_size x 3 x 64 x 64
            # we need pairs of images in our batch
            # print('inputs, labels ', labels)
            # and +1/-1 labels matrix

            labels_matrix = utils.get_labels_matrix_fast(labels,
                                                         labels).view(-1, 1)

            indices_for_loss = get_indices_for_loss(labels_matrix,
                                                    negative_pair_sign=0)

            labels_matrix = labels_matrix[indices_for_loss]
            labels_matrix = Variable(labels_matrix).cuda()
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            # here we should create input pair for the network from just inputs
            outputs = network(Variable(inputs).cuda())

            outputs = outputs[indices_for_loss.cuda(), :]
            # print('outputs ', outputs)
            # print('labels_matrix.long().view(-1, 1).squeeze() ', labels_matrix.long().view(-1, 1).squeeze())
            loss = criterion(outputs,
                             labels_matrix.long().view(-1, 1).squeeze())

            loss.backward()
            optimizer.step()

            # print statistics
            current_batch_loss = loss.data[0]

            if i % 10 == 0:  # print every 2000 mini-batches
                print('[epoch %d, iteration in the epoch %5d] loss: %.30f' %
                      (epoch + 1, i + 1, current_batch_loss))
                # print('PCA matrix ', network.spoc.PCA_matrix)

                r_loss.append(current_batch_loss)
                iterations.append(total_iteration + i)

                options = dict(legend=['loss'])
                loss_plot = vis.line(Y=np.array(r_loss),
                                     X=np.array(iterations),
                                     win=loss_plot,
                                     opts=options)

        lr_scheduler.step(epoch=epoch, metrics=current_batch_loss)

        if epoch % 10 == 0:
            epochs.append(epoch)
            # print the quality metric
            gc.collect()

            print('Evaluation on train internal', datetime.datetime.now())
            average_f1 = test.test_for_binary_classification_1_batch(
                train_loader, network)
            r_average_f1.append(average_f1)
            options = dict(legend=['average_f1'])
            average_f1_plot = vis.line(Y=np.array(r_average_f1),
                                       X=np.array(epochs),
                                       win=average_f1_plot,
                                       opts=options)

            print('Evaluation on test internal', datetime.datetime.now())
            average_f1 = test.test_for_binary_classification_1_batch(
                test_loader, network)

            utils.save_checkpoint(
                network=network,
                optimizer=optimizer,
                filename=params.
                name_prefix_for_saved_model_for_binary_classification + '-%d' %
                (epoch),
                epoch=epoch)
        total_iteration = total_iteration + i

    print('Finished Training for binary classification')
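
The visdom plots above are kept live by creating a window once and then redrawing into it with the win= handle on every update. A minimal sketch of that pattern, assuming a visdom server is already running on the default port:

import numpy as np
import visdom

vis = visdom.Visdom()
opts = dict(legend=['loss'])
win = vis.line(Y=np.zeros(1), X=np.zeros(1), opts=opts)  # create the window once

losses, iterations = [], []
for it in range(1, 6):
    losses.append(1.0 / it)   # placeholder loss values
    iterations.append(it)
    # redraw into the same window so the curve grows instead of spawning new plots
    win = vis.line(Y=np.array(losses), X=np.array(iterations), win=win, opts=opts)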
Exemple #25
0
def train(args, model, optimizer, criterion, dataloader_train, dataloader_train_val, dataloader_val):
    comments=os.getcwd().split('/')[-1]
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(args.log_dirs, comments+'_'+current_time + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)

    step = 0
    best_pred=0.0
    for epoch in range(args.num_epochs):
        lr = u.adjust_learning_rate(args,optimizer,epoch) 
        model.train()
        # if epoch>=args.train_val_epochs:
        #     dataloader_train=dataloader_train_val
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        train_loss=0.0
#        is_best=False
        for i,(data, label) in enumerate(dataloader_train):
            # if i>len(dataloader_train)-2:
            #     break
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()
            optimizer.zero_grad()
            aux_out,main_out = model(data)
            # get weight_map
            weight_map=torch.zeros(args.num_classes)
            weight_map=weight_map.cuda()
            for ind in range(args.num_classes):
                weight_map[ind]=1/(torch.sum((label==ind).float())+1.0)
            # print(weight_map)

            loss_aux=F.nll_loss(main_out,label,weight=None)
            loss_main= criterion[1](main_out, label)

            loss =loss_main+loss_aux
            loss.backward()
            optimizer.step()
            tq.update(args.batch_size)
            train_loss += loss.item()
            tq.set_postfix(loss='%.6f' % (train_loss/(i+1)))
            step += 1
            if step%10==0:
                writer.add_scalar('Train/loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('Train/loss_epoch', float(loss_train_mean), epoch)
        print('loss for train : %f' % (loss_train_mean))
        

        if epoch % args.validation_step == 0:
            Dice1,Dice2,Dice3,Dice4= val(args, model, dataloader_val)
            writer.add_scalar('Valid/Dice1_val', Dice1, epoch)
            writer.add_scalar('Valid/Dice2_val', Dice2, epoch)
            writer.add_scalar('Valid/Dice3_val', Dice3, epoch)
            writer.add_scalar('Valid/Dice4_val', Dice4, epoch)
           
            mean_Dice=(Dice1+Dice2+Dice3+Dice4)/4.0
            is_best=mean_Dice > best_pred
            best_pred = max(best_pred, mean_Dice)
            checkpoint_dir = args.save_model_path
            # checkpoint_dir=os.path.join(checkpoint_dir_root,str(k_fold))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            checkpoint_latest =os.path.join(checkpoint_dir, 'checkpoint_latest.pth.tar')
            u.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_dice': best_pred,
                    }, best_pred,epoch,is_best, checkpoint_dir,filename=checkpoint_latest)
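
The weight_map computed above is an inverse-frequency class weighting: each class is weighted by 1 / (pixel count + 1) so that rare classes contribute more. A standalone sketch of that computation with toy values (num_classes and the label map here are illustrative):

import torch

def inverse_frequency_weights(label, num_classes):
    # Weight each class by 1 / (pixel count + 1); rarer classes get larger weights.
    weights = torch.zeros(num_classes)
    for ind in range(num_classes):
        weights[ind] = 1.0 / (torch.sum((label == ind).float()) + 1.0)
    return weights

label = torch.randint(0, 4, (1, 64, 64))  # toy 4-class label map
print(inverse_frequency_weights(label, 4))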
Exemple #26
0
def main(args):
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = './data/ModelNet/'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(
        str(experiment_dir) + '/%sModelNet40-' % args.model_name +
        str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        str(log_dir) + 'train_%s_cls.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRAINING---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)
    '''DATA LOADING'''
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(
        datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    trainDataset = ModelNetDataLoader(train_data, train_label)
    testDataset = ModelNetDataLoader(test_data, test_label)
    trainDataLoader = torch.utils.data.DataLoader(trainDataset,
                                                  batch_size=args.batchsize,
                                                  shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)
    '''MODEL LOADING'''
    num_class = 40
    classifier = PointConvClsSsg(num_class).cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=30,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'
    '''TRAINING'''
    logger.info('Start training...')
    first_time = True
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1,
                    args.epoch)

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            #construct_planes(points[0])
            optimizer.zero_grad()
            classifier = classifier.train()
            pred = classifier(points)
            loss = F.nll_loss(pred, target.long())

            loss.backward()
            optimizer.step()
            global_step += 1

        train_acc = test(classifier.eval(), trainDataLoader,
                         False) if args.train_metric else None
        acc = test(classifier, testDataLoader, False)

        print('\r Loss: %f' % loss.data)
        logger.info('Loss: %.2f', loss.data)
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print(
            '\r Test %s: %f   ***  %s: %f' %
            (blue('Accuracy'), acc, blue('Best Accuracy'), best_tst_accuracy))
        logger.info('Test Accuracy: %f  *** Best Test Accuracy: %f', acc,
                    best_tst_accuracy)

        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(global_epoch + 1,
                            train_acc if args.train_metric else 0.0, acc,
                            classifier, optimizer, str(checkpoints_dir),
                            args.model_name)
            print('Saving model....')
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')
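
The schedule in this example halves the learning rate every 30 epochs via StepLR. A minimal sketch of that schedule on a stand-in model; note that in PyTorch 1.1.0 and later scheduler.step() is called after optimizer.step(), i.e. at the end of the epoch rather than the start:

import torch

model = torch.nn.Linear(3, 40)  # stand-in for the classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

for epoch in range(90):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 3)).sum()  # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # halves the learning rate every 30 epochs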
Exemple #27
0
def main():
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown initial type'.format(args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2, args.grow_step3)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer
        cur_stage = checkpoint['cur_stage']
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        controller.load_state_dict(checkpoint['ctrl_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ctrl_optimizer.load_state_dict(checkpoint['ctrl_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args, 2**(cur_stage+3))
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter), int(args.max_search_iter)), desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter in [args.grow_step1, args.grow_step2, args.grow_step3]:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f'=> grow to stage {cur_stage}')
            prev_archs, prev_hiddens = get_topk_arch_hidden(args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

            dataset = datasets.ImageDataset(args, 2 ** (cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(args, gen_net, dis_net, g_loss_history, d_loss_history, controller, gen_optimizer,
                                     dis_optimizer, train_loader, prev_hiddens=prev_hiddens, prev_archs=prev_archs)
        train_controller(args, controller, ctrl_optimizer, gen_net, prev_hiddens, prev_archs, writer_dict)

        if dynamic_reset:
            logger.info('re-initialize share GAN')
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(args, weights_init)

        save_checkpoint({
            'cur_stage': cur_stage,
            'search_iter': search_iter + 1,
            'gen_model': args.gen_model,
            'dis_model': args.dis_model,
            'controller': args.controller,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'ctrl_state_dict': controller.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'ctrl_optimizer': ctrl_optimizer.state_dict(),
            'prev_archs': prev_archs,
            'prev_hiddens': prev_hiddens,
            'path_helper': args.path_helper
        }, False, args.path_helper['ckpt_path'])

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net, prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
Exemple #28
0
def train_and_evaluate(model,
                       train_dataloader,
                       val_dataloader,
                       optimizer,
                       loss_fn,
                       metrics,
                       params,
                       model_dir,
                       logger,
                       restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) - name of file to restore from (without its extension .pth.tar)
    """

    best_val_acc = 0.0
    # reload weights from restore_file if specified
    if restore_file is not None:
        logging.info("Restoring parameters from {}".format(restore_file))
        checkpoint = utils.load_checkpoint(restore_file, model, optimizer)
        params.start_epoch = checkpoint['epoch']

        best_val_acc = checkpoint['best_val_acc']
        print('best_val_acc=', best_val_acc)
        print(optimizer.state_dict()['param_groups'][0]['lr'],
              checkpoint['epoch'])

    # learning rate schedulers for different models:
    if params.lr_decay_type == None:
        logging.info("no lr decay")
    else:
        assert params.lr_decay_type in ['multistep', 'exp', 'plateau']
        logging.info("lr decay:{}".format(params.lr_decay_type))
    if params.lr_decay_type == 'multistep':
        scheduler = MultiStepLR(optimizer,
                                milestones=params.lr_step,
                                gamma=params.scheduler_gamma,
                                last_epoch=params.start_epoch - 1)

    elif params.lr_decay_type == 'exp':
        scheduler = ExponentialLR(optimizer,
                                  gamma=params.scheduler_gamma2,
                                  last_epoch=params.start_epoch - 1)
    elif params.lr_decay_type == 'plateau':
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=params.scheduler_gamma3,
                                      patience=params.patience,
                                      verbose=False,
                                      threshold=0.0001,
                                      threshold_mode='rel',
                                      cooldown=0,
                                      min_lr=0,
                                      eps=1e-08)

    for epoch in range(params.start_epoch, params.num_epochs):
        params.current_epoch = epoch
        if params.lr_decay_type != 'plateau':
            scheduler.step()

        # Run one epoch
        logger.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train_metrics, train_confusion_meter = train(model, optimizer, loss_fn,
                                                     train_dataloader, metrics,
                                                     params, logger)

        # Evaluate for one epoch on validation set
        val_metrics, val_confusion_meter, _ = evaluate(model, loss_fn,
                                                       val_dataloader, metrics,
                                                       params, logger)

        # vis logger
        accs = [
            100. * (1 - train_metrics['accuracytop1']),
            100. * (1 - train_metrics['accuracytop5']),
            100. * (1 - val_metrics['accuracytop1']),
            100. * (1 - val_metrics['accuracytop5']),
        ]
        error_logger15.log([epoch] * 4, accs)

        losses = [train_metrics['loss'], val_metrics['loss']]
        loss_logger.log([epoch] * 2, losses)
        train_confusion_logger.log(train_confusion_meter.value())
        test_confusion_logger.log(val_confusion_meter.value())

        # log split loss
        if epoch == params.start_epoch:
            loss_key = []
            for key in [k for k, v in train_metrics.items()]:
                if 'ls' in key: loss_key.append(key)
            loss_split_key = ['train_' + k for k in loss_key
                              ] + ['val_' + k for k in loss_key]
            loss_logger_split.opts['legend'] = loss_split_key

        loss_split = [train_metrics[k]
                      for k in loss_key] + [val_metrics[k] for k in loss_key]
        loss_logger_split.log([epoch] * len(loss_split_key), loss_split)

        if params.lr_decay_type == 'plateau':
            scheduler.step(val_metrics['ls_all'])

        val_acc = val_metrics['accuracytop1']
        is_best = val_acc >= best_val_acc
        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc
            },
            epoch=epoch + 1,
            is_best=is_best,
            save_best_ever_n_epoch=params.save_best_ever_n_epoch,
            checkpointpath=params.experiment_path + '/checkpoint',
            start_epoch=params.start_epoch)

        val_metrics['best_epoch'] = epoch + 1
        # If best_eval, best_save_path, metric
        if is_best:
            logger.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(params.experiment_path,
                                          "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(params.experiment_path,
                                      "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
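
The example selects between three learning-rate schedulers and, for 'plateau', passes the monitored metric into step(). A compact sketch of that selection (the milestone values and gammas here are placeholders, not the params used above):

from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, ReduceLROnPlateau

def build_scheduler(optimizer, decay_type):
    if decay_type == 'multistep':
        return MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)
    if decay_type == 'exp':
        return ExponentialLR(optimizer, gamma=0.95)
    if decay_type == 'plateau':
        # ReduceLROnPlateau needs the monitored value: scheduler.step(val_loss)
        return ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    return None  # no decay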
Exemple #29
0
def main():    
    opt = TrainOptions().parse()  # see the options folder
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0 and torch.cuda.is_available() else "cpu")
    
    # logging
    visualizer = Visualizer(opt)  
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss', 'train/enhance_loss', 'val/enhance_loss'], 'loss.png')
     
    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(opt, os.path.join(opt.dataroot, 'train'), os.path.join(opt.dict_dir, 'train_units.txt'),) 
    val_dataset   = MixSequentialDataset(opt, os.path.join(opt.dataroot, 'dev'), os.path.join(opt.dict_dir, 'train_units.txt'),)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size) 
    train_loader  = MixSequentialDataLoader(train_dataset, num_workers=opt.num_workers, batch_sampler=train_sampler)
    val_loader    = MixSequentialDataLoader(val_dataset, batch_size=int(opt.batch_size/2), num_workers=opt.num_workers, shuffle=False)
    opt.idim = train_dataset.get_feat_size() 
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")
    
    # Setup an model
    lr = opt.lr                            # learning rate
    eps = opt.eps                          # Epsilon constant for optimizer
    iters = opt.iters                      # manual iters number (useful on restarts)
    best_acc = opt.best_acc                # best_acc
    best_loss = opt.best_loss              # best_loss
    start_epoch = opt.start_epoch          # manual iters number (useful on restarts)
    
    enhance_model_path = None
    # path to latest checkpoint (default: none)
    if opt.enhance_resume: 
        enhance_model_path = os.path.join(opt.works_dir, opt.enhance_resume)
        if os.path.isfile(enhance_model_path):
            enhance_model = EnhanceModel.load_model(enhance_model_path, 'enhance_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(enhance_model_path))     
    
    asr_model_path = None
    # path to latest checkpoint (default: none)
    if opt.asr_resume:
        asr_model_path = os.path.join(opt.works_dir, opt.asr_resume)
        if os.path.isfile(asr_model_path):
            asr_model = ShareE2E.load_model(asr_model_path, 'asr_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(asr_model_path))  
                                        
    joint_model_path = None
    # path to latest checkpoint (default: none)
    if opt.joint_resume:
        joint_model_path = os.path.join(opt.works_dir, opt.joint_resume)
        if os.path.isfile(joint_model_path):
            package = torch.load(joint_model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)  
            best_acc = package.get('best_acc', 0)      
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))   
            iters = int(package.get('iters', 0)) - 1   
            print('joint_model_path {} and iters {}'.format(joint_model_path, iters))        
            ##loss_report = package.get('loss_report', loss_report)
            ##visualizer.set_plot_report(loss_report, 'loss.png')
        else:
            print("no checkpoint found at {}".format(joint_model_path))
    if joint_model_path is not None or enhance_model_path is None:     
        enhance_model = EnhanceModel.load_model(joint_model_path, 'enhance_state_dict', opt)    
    if joint_model_path is not None or asr_model_path is None:  
        asr_model = ShareE2E.load_model(joint_model_path, 'asr_state_dict', opt)     
    feat_model = FbankModel.load_model(joint_model_path, 'fbank_state_dict', opt)
    # NOTE: the isGAN branch is never exercised in this example
    if opt.isGAN:
        gan_model = GANModel.load_model(joint_model_path, 'gan_state_dict', opt) 
    ##set_requires_grad([enhance_model], False)    
    
    # Setup an optimizer
    enhance_parameters = filter(lambda p: p.requires_grad, enhance_model.parameters())
    asr_parameters = filter(lambda p: p.requires_grad, asr_model.parameters())
    if opt.isGAN:
        gan_parameters = filter(lambda p: p.requires_grad, gan_model.parameters())   
    # Optimizer
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters, rho=0.95, eps=eps)
        asr_optimizer = torch.optim.Adadelta(asr_parameters, rho=0.95, eps=eps)
        if opt.isGAN:
            gan_optimizer = torch.optim.Adadelta(gan_parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters, lr=lr, betas=(opt.beta1, 0.999))   
        asr_optimizer = torch.optim.Adam(asr_parameters, lr=lr, betas=(opt.beta1, 0.999)) 
        if opt.isGAN:                      
            gan_optimizer = torch.optim.Adam(gan_parameters, lr=lr, betas=(opt.beta1, 0.999))
    if opt.isGAN:
        criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)
       
    # Training	
    enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model) 
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)  
    sche_samp_rate = sample_rampup.update(iters)
    
    enhance_model.train()
    feat_model.train()
    asr_model.train()               	                    
    for epoch in range(start_epoch, opt.epochs):               
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)  
        for i, (data) in enumerate(train_loader, start=0):
            utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
            enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes) 
            enhance_feat = feat_model(enhance_out)
            clean_feat = feat_model(clean_inputs)
            mix_feat = feat_model(mix_inputs)
            if opt.enhance_loss_type == 'L2':
                enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'L1':
                enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'smooth_L1':
                enhance_loss = F.smooth_l1_loss(enhance_feat, clean_feat.detach())
            enhance_loss = opt.enhance_loss_lambda * enhance_loss
                
            loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(clean_feat, enhance_feat, targets, input_sizes, target_sizes, sche_samp_rate, enhance_cmvn) 
            coral_loss = opt.coral_loss_lambda * CORAL(clean_context, mix_context)              
            asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            loss = asr_loss + enhance_loss + coral_loss
                    
            if opt.isGAN:
                set_requires_grad([gan_model], False)
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                else:
                    fake_AB = enhance_feat
                gan_loss = opt.gan_loss_lambda * criterionGAN(gan_model(fake_AB, enhance_cmvn), True)
                loss += gan_loss
                                              
            enhance_optimizer.zero_grad()
            asr_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()          
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()
                asr_optimizer.step()                
            
            if opt.isGAN:
                set_requires_grad([gan_model], True)   
                gan_optimizer.zero_grad()
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                    real_AB = torch.cat((mix_feat, clean_feat), 2)
                else:
                    fake_AB = enhance_feat
                    real_AB = clean_feat
                loss_D_real = criterionGAN(gan_model(real_AB.detach(), enhance_cmvn), True)
                loss_D_fake = criterionGAN(gan_model(fake_AB.detach(), enhance_cmvn), False)
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(gan_model.parameters(), opt.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    gan_optimizer.step()
                                               
            iters += 1
            errors = {'train/loss': loss.item(), 'train/loss_ctc': loss_ctc.item(), 
                      'train/acc': acc, 'train/loss_att': loss_att.item(), 
                      'train/enhance_loss': enhance_loss.item(), 'train/coral_loss': coral_loss.item()}
            if opt.isGAN:
                errors['train/loss_D'] = loss_D.item()
                errors['train/gan_loss'] = opt.gan_loss_lambda * gan_loss.item()  
              
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {'asr_state_dict': asr_model.state_dict(), 
                         'fbank_state_dict': feat_model.state_dict(), 
                         'enhance_state_dict': enhance_model.state_dict(), 
                         'opt': opt, 'epoch': epoch, 'iters': iters, 
                         'eps': opt.eps, 'lr': opt.lr,                                    
                         'best_loss': best_loss, 'best_acc': best_acc, 
                         'acc_report': acc_report, 'loss_report': loss_report}
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                filename='latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                    
            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))    
                enhance_model.eval() 
                feat_model.eval() 
                asr_model.eval()
                torch.set_grad_enabled(False)                
                num_saved_attention = 0 
                for i, (data) in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
                    enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)                         
                    enhance_feat = feat_model(enhance_out)
                    clean_feat = feat_model(clean_inputs)
                    mix_feat = feat_model(mix_inputs)
                    if opt.enhance_loss_type == 'L2':
                        enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'L1':
                        enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'smooth_L1':
                        enhance_loss = F.smooth_l1_loss(enhance_feat, clean_feat.detach())
                    if opt.isGAN:
                        set_requires_grad([gan_model], False)
                        if opt.netD_type == 'pixel':
                            fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                        else:
                            fake_AB = enhance_feat
                        gan_loss = criterionGAN(gan_model(fake_AB, enhance_cmvn), True)
                        enhance_loss += opt.gan_loss_lambda * gan_loss
                        
                    loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(clean_feat, enhance_feat, targets, input_sizes, target_sizes, 0.0, enhance_cmvn)
                                                  
                    asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
                    enhance_loss = opt.enhance_loss_lambda * enhance_loss
                    loss = asr_loss + enhance_loss                          
                    errors = {'val/loss': loss.item(), 'val/loss_ctc': loss_ctc.item(), 
                              'val/acc': acc, 'val/loss_att': loss_att.item(),
                              'val/enhance_loss': enhance_loss.item()}
                    if opt.isGAN:        
                        errors['val/gan_loss'] = opt.gan_loss_lambda * gan_loss.item()  
                    visualizer.set_current_errors(errors)
                
                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
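                        # dump attention plots for a handful of validation utterances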
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(enhance_feat, targets, input_sizes, target_sizes, enhance_cmvn)                            
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x]) 
                                visualizer.plot_attention(att_w, dec_len, enc_len, file_name) 
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:   
                                    break 
                enhance_model.train()
                feat_model.train()
                asr_model.train() 
                torch.set_grad_enabled(True)  
				
                visualizer.print_epoch_errors(epoch, iters)  
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png') 
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png') 
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc') 
                filename = None                
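                # if the validation metric did not improve, decay Adadelta's eps;
                # otherwise mark this checkpoint as the new best model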
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                    else:
                        filename='model.acc.best'                    
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))  
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                    else:
                        filename='model.loss.best'    
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))                  
                state = {'asr_state_dict': asr_model.state_dict(), 
                         'fbank_state_dict': feat_model.state_dict(), 
                         'enhance_state_dict': enhance_model.state_dict(), 
                         'opt': opt, 'epoch': epoch, 'iters': iters, 
                         'eps': opt.eps, 'lr': opt.lr,                                    
                         'best_loss': best_loss, 'best_acc': best_acc, 
                         'acc_report': acc_report, 'loss_report': loss_report}
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                utils.save_checkpoint(state, opt.exp_path, filename=filename)                  
                visualizer.reset()  
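                # recompute enhance_cmvn over the training set with the updated models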
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)  
Exemple #30
0
def main():
    env = gym.make(args.env_name)
    env.seed(args.seed)
    torch.manual_seed(args.seed)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.shape[0]
    running_state = ZFilter((num_inputs,), clip=5)

    print('state size:', num_inputs) 
    print('action size:', num_actions)

    actor = Actor(num_inputs, num_actions, args)
    critic = Critic(num_inputs, args)

    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(), lr=args.learning_rate, 
                              weight_decay=args.l2_rate)

    writer = SummaryWriter(comment="-ppo_iter-" + str(args.max_iter_num))
    
    if args.load_model is not None:
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model', str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)

        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])

        running_state.rs.n = ckpt['z_filter_n']
        running_state.rs.mean = ckpt['z_filter_m']
        running_state.rs.sum_square = ckpt['z_filter_s']

        print("Loaded OK ex. Zfilter N {}".format(running_state.rs.n))

    
    episodes = 0    

    for iter in range(args.max_iter_num):
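        # each iteration: collect roughly total_sample_size environment steps,
        # then update the actor/critic (PPO) on the collected memory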
        actor.eval()
        critic.eval()
        memory = deque()

        steps = 0
        scores = []

        while steps < args.total_sample_size: 
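            # roll out full episodes until enough transitions have been gathered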
            state = env.reset()
            score = 0

            state = running_state(state)
            
            for _ in range(10000): 
                if args.render:
                    env.render()

                steps += 1

                mu, std = actor(torch.Tensor(state).unsqueeze(0))
                action = get_action(mu, std)[0]
                next_state, reward, done, _ = env.step(action)

                if done:
                    mask = 0
                else:
                    mask = 1

                memory.append([state, action, reward, mask])

                next_state = running_state(next_state)
                state = next_state

                score += reward

                if done:
                    break
            
            episodes += 1
            scores.append(score)
        
        score_avg = np.mean(scores)
        print('{}:: {} episode score is {:.2f}'.format(iter, episodes, score_avg))
        writer.add_scalar('log/score', float(score_avg), iter)

        actor.train()
        critic.train()
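        # run the policy/value update on the transitions collected this iteration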
        train_model(actor, critic, memory, actor_optim, critic_optim, args)

        if iter % 100 == 0:
            score_avg = int(score_avg)

            model_path = os.path.join(os.getcwd(),'save_model')
            if not os.path.isdir(model_path):
                os.makedirs(model_path)

            ckpt_path = os.path.join(model_path, 'ckpt_'+ str(score_avg)+'.pth.tar')

            save_checkpoint({
                'actor': actor.state_dict(),
                'critic': critic.state_dict(),
                'z_filter_n':running_state.rs.n,
                'z_filter_m': running_state.rs.mean,
                'z_filter_s': running_state.rs.sum_square,
                'args': args,
                'score': score_avg
            }, filename=ckpt_path)
Exemple #31
0
def main(
    args,
    vis,
):
    best_seen = -1
    best_harmonic = -1
    best_epoch = -1
    best_unseen = -1
    print(str(args))
    # np.random.seed(args.seed)
    # torch.manual_seed(args.seed)
    # cudnn.benchmark = True
    if vis is not None:
        vis.text(str(args), win="args")
    rnn_model = None
    fv = None
    best_harmonic = 0
    torch.autograd.set_detect_anomaly(True)
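    # choose the class-embedding encoder: a pre-trained recurrent/affine/fc text model
    # (checkpoint loaded below) or, alternatively, a Fisher-vector encoder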
    if args.att in ["rnn", "lstm", "gru", "affine", "fc", "fcb", "lstm2"]:
        #rnn_model, rnn_loader, best_loss, best_epoch = train_rnn(args,vis)
        rnn_loader, _ = get_data_rnn(args)
        if args.att == "rnn":
            rnn_model = torch.nn.RNN(300, 300, 1)
        elif args.att in ["lstm", "lstm2"]:
            rnn_model = torch.nn.LSTM(300, 300, 1)
        elif args.att == "gru":
            rnn_model = torch.nn.GRU(300, 300, 1)
        elif args.att == "affine":
            rnn_model = Affine(word_embed_size=300)
        elif args.att == "fc":
            rnn_model = FC(word_embed_size=300)
        elif args.att == "fcb":
            rnn_model = FC(word_embed_size=300, bias=True)

        rnn_model = rnn_model.cuda()
        checkpoint = load_checkpoint(
            osp.join(
                "./models/",
                args.att + "_" + args.rnn_cost + '_checkpoint_best.pth.tar'))
        rnn_model.load_state_dict(checkpoint['state_dict'])
        rnn_model = rnn_model.cuda()
    else:
        rnn_loader, _ = get_data_rnn(args)
        if args.att == "fisher":
            fv = FisherVector(rnn_loader.dataset.get_all_embeddings(30000),
                              args.kmeansk)
            fv.train()
    print("Rnn Model: ")
    print(rnn_model)

    train_loader, val_loader, seen_loader, unseen_loader = get_data(
        args, rnn_loader.dataset.is_in)
    print("Got data")
    train_embedding_matrix = get_em(args,
                                    train_loader,
                                    rnn_loader=rnn_loader,
                                    model=rnn_model,
                                    fv=fv,
                                    mode="train").cuda()
    val_embedding_matrix = get_em(args,
                                  val_loader,
                                  rnn_loader=rnn_loader,
                                  model=rnn_model,
                                  fv=fv,
                                  mode="val").cuda()
    all_embedding_matrix = get_em(args,
                                  val_loader,
                                  rnn_loader=rnn_loader,
                                  model=rnn_model,
                                  fv=fv,
                                  mode="all").cuda()
    unseen_embedding_matrix = get_em(args,
                                     val_loader,
                                     rnn_loader=rnn_loader,
                                     model=rnn_model,
                                     fv=fv,
                                     mode="unseen").cuda()
    if not args.joint:
        if rnn_model is not None:
            rnn_model = rnn_model.cpu()
        rnn_loader = None
        rnn_model = None
    print("got embeddings")

    model = ALE(train_embedding_matrix,
                img_embed_size=train_loader.dataset.get_image_embed_size(),
                dropout=args.dropout,
                batch_size=args.batch_size)
    model = model.cuda()
    if args.joint:
        checkpoint = load_checkpoint(
            osp.join(
                "./models/",
                str(False) + "_" + args.att + "_" + args.cost +
                '_checkpoint_best.pth.tar'))
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()

    #model = nn.DataParallel(model).cuda()
    param_groups = model.parameters()

    if args.cost == "ALE":
        criterion = ale_loss
    elif args.cost == "CEL":
        print("Using cross-entrophy loss")
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.cost == "WARP":
        criterion = WARPLoss()
    else:
        assert False, "Unknown cost function"

    if args.joint:
        if args.rnn_cost == "MSE":
            rnn_criterion = torch.nn.MSELoss()
        elif args.rnn_cost == "COS":
            cos_loss = torch.nn.CosineEmbeddingLoss()  # margin can be added
            # keep the target tensor on the same device as the embeddings
            rnn_criterion = lambda x, y: cos_loss(
                x, y, torch.ones(x.shape[0], device=x.device))
        else:
            assert False, "Unknown rnn cost function"
    """
    optimizer = torch.optim.SGD(param_groups, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    """
    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.wd,
                                 amsgrad=True)
    if args.joint and args.att in [
            "rnn", "lstm", "gru", "affine", "fc", "lstm2"
    ]:
        rnn_optimizer = torch.optim.Adam(
            rnn_model.parameters()
            if args.att != "random" else [train_embedding_matrix],
            lr=args.joint_lr,
            weight_decay=args.joint_lr,
            amsgrad=True)

    def adjust_lr(epoch):
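        # step decay: divide the learning rate(s) by 10 at 1/3 and 2/3 of training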
        if epoch != 0 and epoch in [args.epochs // 3, 2 * args.epochs // 3]:

            for g in optimizer.param_groups:
                g['lr'] *= 0.1
                print('=====> adjust lr to {}'.format(g['lr']))
            if args.joint and args.att in [
                    "rnn", "lstm", "gru", "affine", "fc", "lstm2"
            ]:
                for g in rnn_optimizer.param_groups:
                    g['lr'] *= 0.1
                    print('=====> adjust lr to {}'.format(g['lr']))

    best_val_pc = -1
    bar = Bar('Training', max=len(train_loader))
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        model.set_embedding(train_embedding_matrix)
        model.train()
        perclass_accuracies = torch.zeros(
            (train_embedding_matrix.shape[0])).cuda()
        if rnn_model is not None:
            rnn_model.train()
        loss = AverageMeter()
        acc1 = AverageMeter()
        acc5 = AverageMeter()

        for i, d in enumerate(train_loader):
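            # forward pass, ALE/CE/WARP loss and optimizer step; in joint mode the
            # train embedding matrix is refreshed from the updated text encoder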
            img_embeds, metas = d
            img_embeds = img_embeds.cuda()
            optimizer.zero_grad()
            if args.joint and args.att in [
                    "rnn", "lstm", "gru", "affine", "fc", "lstm2", "fcb"
            ]:
                rnn_optimizer.zero_grad()
            comps = model(img_embeds)
            classes = metas["class"].cuda()
            loss_value = criterion(comps, classes)
            loss_value.backward()
            loss.update(loss_value.item(), img_embeds.size(0))
            optimizer.step()

            if args.joint and args.att in [
                    "rnn", "lstm", "gru", "affine", "fc", "lstm2", "random",
                    "fcb"
            ]:
                rnn_optimizer.step()
                train_embedding_matrix = get_em(args,
                                                train_loader,
                                                rnn_loader=rnn_loader,
                                                model=rnn_model,
                                                fv=fv,
                                                mode="train").cuda()
                model.set_embedding(train_embedding_matrix)

            acc1_train = top1_acc(classes, comps)
            acc5_train = top5_acc(classes, comps)
            top1_acc_perclass(classes, comps, perclass_accuracies)
            acc1.update(acc1_train, img_embeds.size(0))
            acc5.update(acc5_train, img_embeds.size(0))
            # plot progress
            bar.suffix = 'Epoch: [{}][{}/{}]\t {}\t Loss  {:.6f}\t Acc1 {:.3f}\t Acc5 {:.3f}\t'.format(
                epoch, (i + 1), len(train_loader), args.att, loss.avg,
                acc1.avg, acc5.avg)
            bar.next()
        bar.finish()
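        # joint mode: train the text encoder for one tick, then recompute the
        # val / all / unseen embedding matrices with the updated encoder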
        if args.joint and args.att in [
                "rnn", "lstm", "gru", "affine", "fc", "lstm2"
        ] and args.att != "random":
            train_rnn_tick(rnn_loader, args, model, rnn_optimizer,
                           rnn_criterion, epoch)
            rnn_model.eval()
            val_embedding_matrix = get_em(args,
                                          val_loader,
                                          rnn_loader=rnn_loader,
                                          model=rnn_model,
                                          fv=fv,
                                          mode="val").cuda()
            all_embedding_matrix = get_em(args,
                                          val_loader,
                                          rnn_loader=rnn_loader,
                                          model=rnn_model,
                                          fv=fv,
                                          mode="all").cuda()
            unseen_embedding_matrix = get_em(args,
                                             val_loader,
                                             rnn_loader=rnn_loader,
                                             model=rnn_model,
                                             fv=fv,
                                             mode="unseen").cuda()
        accpc = calc_perclass(perclass_accuracies,
                              train_loader.dataset.train_sample_per_class,
                              "training")
        print("Train Accpc: %f" % accpc)

        acc1_val, acc5_val, accpc_val, loss_val = test(val_loader,
                                                       args,
                                                       em=val_embedding_matrix,
                                                       model=model,
                                                       criterion=criterion)
        if vis is not None:
            zsl_acc, zsl_acc_seen, zsl_acc_unseen = evaluate(
                args,
                eval_func,
                args.dset,
                all_embedding_matrix,
                unseen_embedding_matrix,
                model=model)
            zsl_harmonic = 2 * (zsl_acc_seen * zsl_acc_unseen) / (
                zsl_acc_seen + zsl_acc_unseen)
            print("Harmonic: %.6f" % zsl_harmonic)
            print("------")
            ##### PLOTS
            #Loss
            draw_vis(vis=vis,
                     title="Loss",
                     name="train",
                     epoch=epoch,
                     value=loss.avg,
                     legend=['train', 'val'])
            draw_vis(vis=vis,
                     title="Loss",
                     name="val",
                     epoch=epoch,
                     value=loss_val,
                     legend=['train', 'val'])
            #acc1
            draw_vis(vis=vis,
                     title="Acc1",
                     name="train",
                     epoch=epoch,
                     value=acc1.avg,
                     legend=['train', 'val'])
            draw_vis(vis=vis,
                     title="Acc1",
                     name="val",
                     epoch=epoch,
                     value=acc1_val,
                     legend=['train', 'val'])
            #acc5
            draw_vis(vis=vis,
                     title="Acc5",
                     name="train",
                     epoch=epoch,
                     value=acc5.avg,
                     legend=['train', 'val'])
            draw_vis(vis=vis,
                     title="Acc5",
                     name="val",
                     epoch=epoch,
                     value=acc5_val,
                     legend=['train', 'val'])
            #accperclass
            draw_vis(vis=vis,
                     title="Accpc",
                     name="train",
                     epoch=epoch,
                     value=accpc,
                     legend=['train', 'val'])
            draw_vis(vis=vis,
                     title="Accpc",
                     name="val",
                     epoch=epoch,
                     value=accpc_val,
                     legend=['train', 'val'])
            #testing
            draw_vis(vis=vis,
                     title="Testing",
                     name="zsl_acc",
                     epoch=epoch,
                     value=zsl_acc,
                     legend=['zsl_acc', 'seen', 'unseen', 'harmonic'])
            draw_vis(vis=vis,
                     title="Testing",
                     name="seen",
                     epoch=epoch,
                     value=zsl_acc_seen,
                     legend=['zsl_acc', 'seen', 'unseen', 'harmonic'])
            draw_vis(vis=vis,
                     title="Testing",
                     name="unseen",
                     epoch=epoch,
                     value=zsl_acc_unseen,
                     legend=['zsl_acc', 'seen', 'unseen', 'harmonic'])
            draw_vis(vis=vis,
                     title="Testing",
                     name="harmonic",
                     epoch=epoch,
                     value=zsl_harmonic,
                     legend=['zsl_acc', 'seen', 'unseen', 'harmonic'])
            ##########
        key = str(args.joint) + "_" + args.att + "_" + args.cost
        with open(key + "pclog.txt", "w") as filem:
            filem.write(str(perclass_accs_global))
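        # headless runs: track the best per-class val accuracy (with a ~1% relative
        # margin), evaluate the seen/unseen test splits and persist the best results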
        if vis is None and accpc_val > best_val_pc + best_val_pc / 100:
            #zsl_acc, zsl_acc_seen, zsl_acc_unseen = evaluate(args,eval_func,args.dset, all_embedding_matrix, unseen_embedding_matrix, model=model)
            #zsl_harmonic = 2*( zsl_acc_seen * zsl_acc_unseen ) / ( zsl_acc_seen + zsl_acc_unseen )
            #print("Harmonic: %.6f" % zsl_harmonic)
            with open("ale_results.pc", "rb") as filem:
                ale_results = pc.load(filem)
            key = str(args.joint) + "_" + args.att + "_" + args.cost

            if key not in ale_results or ale_results[key][
                    "best_valpc"] + ale_results[key][
                        "best_valpc"] / 100 < accpc_val:
                best_val_pc = accpc_val
                _, _, accpc_seen, _ = test(seen_loader,
                                           args,
                                           em=val_embedding_matrix,
                                           model=model,
                                           criterion=criterion,
                                           strm="Seen")
                _, _, accpc_unseen, _ = test(unseen_loader,
                                             args,
                                             em=val_embedding_matrix,
                                             model=model,
                                             criterion=criterion,
                                             strm="Unseen")
                test_harmonic = 2 * (accpc_seen * accpc_unseen) / (
                    accpc_seen + accpc_unseen)
                print("Test Harmonic: %.6f" % test_harmonic)
                print("------")
                best_harmonic = test_harmonic
                best_seen = accpc_seen
                best_unseen = accpc_unseen
                best_epoch = epoch
                print("FOUND NEW BEST: ", key)
                with open("what_changed.txt", "w") as filem:
                    filem.write("Found new best: %s\n" % key)
                with open(key + "pclog.txt", "w") as filem:
                    filem.write(str(perclass_accs_global))
                ale_results[key] = {}
                ale_results[key]["args"] = str(args)
                ale_results[key]["best_valpc"] = accpc_val
                ale_results[key]["best_epoch"] = best_epoch
                ale_results[key]["best_harmonic"] = best_harmonic
                ale_results[key]["best_seen"] = best_seen
                ale_results[key]["best_unseen"] = best_unseen
                with open("ale_results.pc", "wb") as filem:
                    pc.dump(ale_results, filem)
                save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'epoch': epoch,
                    },
                    False,
                    fpath=osp.join("./models/",
                                   key + '_checkpoint_best.pth.tar'))
                with open("./models/" + key + '_checkpoint_best.txt',
                          "w") as filem:
                    filem.write(str(args))
        print("------")

    if vis is None:
        return best_val_pc, best_harmonic, best_seen, best_unseen, best_epoch