Example 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--mode',
                        type=str,
                        required=True,
                        help='"train", "test", or "general"')
    parser.add_argument('-n',
                        '--model',
                        type=str,
                        help='path to a trained model')
    parser.add_argument('-d',
                        '--data',
                        type=str,
                        help='path to financial data from Yahoo Finance')
    args = parser.parse_args()

    if args.mode == 'test':
        assert args.model is not None
        assert args.data is not None
        test(args.data, args.model)

    elif args.mode == 'train':
        assert args.data is not None
        train(args.data)

    elif args.mode == 'general':
        assert args.data is None
        train_general()
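A note on the block above: it assumes train(), test(), and train_general() are defined earlier in the same file, and the asserts enforce which flags each mode needs. The mode validation itself could also be expressed directly through argparse via choices; a minimal sketch of that alternative (not the project's code, and it still leaves the per-mode --model/--data checks to the caller):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', required=True,
                    choices=('train', 'test', 'general'),
                    help='operation to perform')
parser.add_argument('-n', '--model', help='path to a trained model')
parser.add_argument('-d', '--data', help='path to financial data from Yahoo Finance')

# Parse an explicit argument list (instead of sys.argv) purely for demonstration.
args = parser.parse_args(['--mode', 'train', '--data', 'prices.csv'])
print(args.mode, args.data, args.model)  # -> train prices.csv None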
Example 2
def train_src(model, params, src_data_loader, tgt_data_loader,
              tgt_data_loader_eval, device, logger):
    """Train dann."""
    ####################
    # 1. setup network #
    ####################

    # setup criterion and optimizer

    if not params.finetune_flag:
        print("training non-office task")
        optimizer = optim.SGD(model.parameters(),
                              lr=params.lr,
                              momentum=params.momentum,
                              weight_decay=params.weight_decay)
    else:
        print("training office task")
        parameter_list = [{
            "params": model.features.parameters(),
            "lr": 0.001
        }, {
            "params": model.fc.parameters(),
            "lr": 0.001
        }, {
            "params": model.bottleneck.parameters()
        }, {
            "params": model.classifier.parameters()
        }, {
            "params": model.discriminator.parameters()
        }]
        optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)

    criterion = nn.CrossEntropyLoss()

    ####################
    # 2. train network #
    ####################
    global_step = 0
    for epoch in range(params.num_epochs):
        # set train state for Dropout and BN layers
        model.train()
        # zip source and target data pair
        len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, class_src), (images_tgt, _)) in data_zip:

            p = float(step + epoch * len_dataloader) / \
                params.num_epochs / len_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            if params.lr_adjust_flag == 'simple':
                lr = adjust_learning_rate(optimizer, p)
            else:
                lr = adjust_learning_rate_office(optimizer, p)
            logger.add_scalar('lr', lr, global_step)

            # prepare domain label
            size_src = len(images_src)
            size_tgt = len(images_tgt)

            # make images variable
            class_src = class_src.to(device)
            images_src = images_src.to(device)

            # zero gradients for optimizer
            model.zero_grad()

            # train on source domain
            src_class_output, src_domain_output = model(input_data=images_src,
                                                        alpha=alpha)
            src_loss_class = criterion(src_class_output, class_src)

            loss = src_loss_class

            # optimize dann
            loss.backward()
            optimizer.step()

            global_step += 1

            # print step info
            logger.add_scalar('loss', loss.item(), global_step)

            if ((step + 1) % params.log_step == 0):
                print("Epoch [{:4d}/{}] Step [{:2d}/{}]: loss={:.6f}".format(
                    epoch + 1, params.num_epochs, step + 1, len_dataloader,
                    loss.data.item()))

        # eval model
        if ((epoch + 1) % params.eval_step == 0):
            src_test_loss, src_acc, src_acc_domain = test(model,
                                                          src_data_loader,
                                                          device,
                                                          flag='source')
            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model,
                                                          tgt_data_loader_eval,
                                                          device,
                                                          flag='target')
            logger.add_scalar('src_test_loss', src_test_loss, global_step)
            logger.add_scalar('src_acc', src_acc, global_step)

        # save model parameters
        if ((epoch + 1) % params.save_step == 0):
            save_model(
                model, params.model_root, params.src_dataset + '-' +
                params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))

    # save final model
    save_model(
        model, params.model_root,
        params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")

    return model
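The helpers adjust_learning_rate and adjust_learning_rate_office used above are not part of the snippet. A minimal sketch of the former, assuming the usual DANN annealing schedule lr_p = lr_0 / (1 + alpha * p) ** beta with alpha = 10 and beta = 0.75 (the constants, and the office variant, may differ in the actual project):

def adjust_learning_rate(optimizer, p, lr_0=0.01, alpha=10.0, beta=0.75):
    """Anneal the learning rate with a DANN-style schedule and apply it to all param groups."""
    lr = lr_0 / (1.0 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr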
Example 3
        from config.classic_control import muzero_config  # just using same config as classic_control for now
    elif args.case == 'classic_control':
        from config.classic_control import muzero_config
    else:
        raise Exception('Invalid --case option')

    # set config as per arguments
    exp_path = muzero_config.set_config(args)
    exp_path, log_base_path = make_results_dir(exp_path, args)

    # set-up logger
    init_logger(log_base_path)

    try:
        if args.opr == 'train':
            summary_writer = SummaryWriter(exp_path, flush_secs=10)
            train(muzero_config, summary_writer)

        elif args.opr == 'test':
            assert os.path.exists(muzero_config.model_path), 'model not found at {}'.format(muzero_config.model_path)
            model = muzero_config.get_uniform_network().to('cpu')
            model.load_state_dict(torch.load(muzero_config.model_path, map_location=torch.device('cpu')))
            test_score = test(muzero_config, model, args.test_episodes, device='cpu', render=args.render,
                              save_video=True)
            logging.getLogger('test').info('Test Score: {}'.format(test_score))
        else:
            raise Exception('Please select a valid operation (--opr) to be performed')
        ray.shutdown()
    except Exception as e:
        logging.getLogger('root').error(e, exc_info=True)
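Both MuZero snippets here (this one and Example 12) depend on helpers such as make_results_dir and init_logger that are defined elsewhere. A rough sketch of make_results_dir that is consistent with how its return value is unpacked above (the directory layout and the role of args are assumptions):

import os

def make_results_dir(exp_path, args):
    """Create the experiment directory plus a log subdirectory and return both paths."""
    # args might carry flags such as --force to control overwriting; ignored in this sketch.
    log_base_path = os.path.join(exp_path, 'logs')
    os.makedirs(log_base_path, exist_ok=True)
    return exp_path, log_base_path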
Example 4
            model = model.to('cpu')
            model.load_state_dict(
                torch.load(model_path, map_location=torch.device('cpu')))

            if args.render and args.case == 'mujoco':
                # Ref: https://github.com/openai/mujoco-py/issues/390
                from mujoco_py import GlfwContext

                GlfwContext(offscreen=True)

            env = run_config.new_game()
            test_score, test_repeat_counts = test(
                env,
                model,
                args.test_episodes,
                device='cpu',
                render=args.render,
                save_test_data=True,
                save_path=run_config.test_data_path,
                recording_path=run_config.recording_path)
            env.close()

            logging.getLogger('test').info('Test Score: {}'.format(test_score))
        else:
            raise ValueError(
                '"--opr {}" is not implemented (or not valid)'.format(
                    args.opr))

    except Exception as e:
        logging.getLogger('root').error(e, exc_info=True)
Example 5
def train(cfg, init_epoch, dataset_loader, train_transforms, val_transforms,
          deblurnet, deblurnet_solver, deblurnet_lr_scheduler, ckpt_dir,
          train_writer, val_writer, Best_Img_PSNR, Best_Epoch):

    n_itr = 0
    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Set up data loader
        train_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetType.TRAIN, train_transforms),
            batch_size=cfg.CONST.TRAIN_BATCH_SIZE,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=True)

        # Tick / tock
        epoch_start_time = time()
        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        deblur_mse_losses = utils.network_utils.AverageMeter()
        if cfg.TRAIN.USE_PERCET_LOSS:
            deblur_percept_losses = utils.network_utils.AverageMeter()
        deblur_losses = utils.network_utils.AverageMeter()
        img_PSNRs = utils.network_utils.AverageMeter()

        # Adjust learning rate
        deblurnet_lr_scheduler.step()
        print('[INFO] learning rate: {0}\n'.format(
            deblurnet_lr_scheduler.get_lr()))

        batch_end_time = time()
        seq_num = len(train_data_loader)

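        # Note: the perceptual-loss VGG below is rebuilt on every epoch;
        # it could be constructed once before the epoch loop instead.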
        vggnet = VGG19()
        if torch.cuda.is_available():
            vggnet = torch.nn.DataParallel(vggnet).cuda()

        for seq_idx, (_, seq_blur, seq_clear) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)
            # Get data from data loader
            seq_blur = [
                utils.network_utils.var_or_cuda(img) for img in seq_blur
            ]
            seq_clear = [
                utils.network_utils.var_or_cuda(img) for img in seq_clear
            ]

            # switch models to training mode
            deblurnet.train()

            # Train the model
            last_img_blur = seq_blur[0]
            output_last_img = seq_blur[0]
            output_last_fea = None
            for batch_idx, [img_blur,
                            img_clear] in enumerate(zip(seq_blur, seq_clear)):
                img_blur_hold = img_blur
                output_img, output_fea = deblurnet(img_blur, last_img_blur,
                                                   output_last_img,
                                                   output_last_fea)

                # deblur loss
                deblur_mse_loss = mseLoss(output_img, img_clear)
                deblur_mse_losses.update(deblur_mse_loss.item(),
                                         cfg.CONST.TRAIN_BATCH_SIZE)
                if cfg.TRAIN.USE_PERCET_LOSS:
                    deblur_percept_loss = perceptualLoss(
                        output_img, img_clear, vggnet)
                    deblur_percept_losses.update(deblur_percept_loss.item(),
                                                 cfg.CONST.TRAIN_BATCH_SIZE)
                    deblur_loss = deblur_mse_loss + 0.01 * deblur_percept_loss
                else:
                    deblur_loss = deblur_mse_loss
                deblur_losses.update(deblur_loss.item(),
                                     cfg.CONST.TRAIN_BATCH_SIZE)
                img_PSNR = PSNR(output_img, img_clear)
                img_PSNRs.update(img_PSNR.item(), cfg.CONST.TRAIN_BATCH_SIZE)

                # deblurnet update
                deblurnet_solver.zero_grad()
                deblur_loss.backward()
                deblurnet_solver.step()

                # Append loss to TensorBoard
                train_writer.add_scalar('STFANet/DeblurLoss_0_TRAIN',
                                        deblur_loss.item(), n_itr)
                train_writer.add_scalar('STFANet/DeblurMSELoss_0_TRAIN',
                                        deblur_mse_loss.item(), n_itr)
                if cfg.TRAIN.USE_PERCET_LOSS:
                    train_writer.add_scalar(
                        'STFANet/DeblurPerceptLoss_0_TRAIN',
                        deblur_percept_loss.item(), n_itr)
                n_itr = n_itr + 1

                # Tick / tock
                batch_time.update(time() - batch_end_time)
                batch_end_time = time()

                # print per batch
                if (batch_idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
                    if cfg.TRAIN.USE_PERCET_LOSS:
                        print(
                            '[TRAIN] [Ech {0}/{1}][Seq {2}/{3}][Bch {4}/{5}] BT {6} DT {7} DeblurLoss {8} [{9}, {10}] PSNR {11}'
                            .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                                    seq_idx + 1, seq_num, batch_idx + 1,
                                    cfg.DATA.SEQ_LENGTH, batch_time, data_time,
                                    deblur_losses, deblur_mse_losses,
                                    deblur_percept_losses, img_PSNRs))
                    else:
                        print(
                            '[TRAIN] [Ech {0}/{1}][Seq {2}/{3}][Bch {4}/{5}] BT {6} DT {7} DeblurLoss {8} PSNR {9}'
                            .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                                    seq_idx + 1, seq_num, batch_idx + 1,
                                    cfg.DATA.SEQ_LENGTH, batch_time, data_time,
                                    deblur_losses, img_PSNRs))

                # show
                if seq_idx == 0 and batch_idx < cfg.TEST.VISUALIZATION_NUM:
                    img_blur = img_blur[0][[2, 1, 0], :, :].cpu(
                    ) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
                    img_clear = img_clear[0][[2, 1, 0], :, :].cpu(
                    ) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
                    output_last_img = output_last_img[0][[2, 1, 0], :, :].cpu(
                    ) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
                    img_out = output_img[0][[2, 1, 0], :, :].cpu().clamp(
                        0.0, 1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)

                    result = torch.cat([
                        torch.cat([img_blur, img_clear], 2),
                        torch.cat([output_last_img, img_out], 2)
                    ], 1)
                    result = torchvision.utils.make_grid(result,
                                                         nrow=1,
                                                         normalize=True)
                    train_writer.add_image(
                        'STFANet/TRAIN_RESULT' + str(batch_idx + 1), result,
                        epoch_idx + 1)

                # *** Update output_last_img/feature ***
                last_img_blur = img_blur_hold
                output_last_img = output_img.clamp(0.0, 1.0).detach()
                output_last_fea = output_fea.detach()

            # print per sequence
            print('[TRAIN] [Epoch {0}/{1}] [Seq {2}/{3}] ImgPSNR_avg {4}\n'.
                  format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, seq_idx + 1,
                         seq_num, img_PSNRs.avg))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('STFANet/EpochPSNR_0_TRAIN', img_PSNRs.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[TRAIN] [Epoch {0}/{1}]\t EpochTime {2}\t ImgPSNR_avg {3}\n'.
              format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                     epoch_end_time - epoch_start_time, img_PSNRs.avg))

        # Validate the training models
        img_PSNR = test(cfg, epoch_idx, dataset_loader, val_transforms,
                        deblurnet, val_writer)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth.tar' % (epoch_idx + 1)), \
                                                      epoch_idx + 1, deblurnet, deblurnet_solver, \
                                                      Best_Img_PSNR, Best_Epoch)
        if img_PSNR >= Best_Img_PSNR:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            Best_Img_PSNR = img_PSNR
            Best_Epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(os.path.join(ckpt_dir, 'best-ckpt.pth.tar'), \
                                                      epoch_idx + 1, deblurnet, deblurnet_solver, \
                                                      Best_Img_PSNR, Best_Epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
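The training loop above leans on utils.network_utils.AverageMeter for running averages and prints the meters directly inside format strings. The implementation is not shown; a minimal sketch of the common pattern (the __str__ format is an assumption, chosen so the progress messages above stay readable):

class AverageMeter(object):
    """Track the latest value, sum, count, and running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # rendered as "current (average)" when formatted into the log messages
        return '{:.4f} ({:.4f})'.format(self.val, self.avg)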
Example 6
	def run(self):
		core.test(self.urlSync, self.spliderThreadPool, self.mutex)
		
Example 7
File: build.py Project: sczhou/IGNN
def bulid_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(cfg.DATA.CROP_IMG_SIZE,
                                         cfg.CONST.SCALE),
        utils.data_transforms.FlipRotate(),
        utils.data_transforms.BGR2RGB(),
        utils.data_transforms.RandomColorChannel(),
        # utils.data_transforms.ColorJitter(cfg.DATA.COLOR_JITTER),
        # utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD),
        # utils.data_transforms.RandomGaussianNoise(cfg.DATA.GAUSSIAN),
        utils.data_transforms.ToTensor()
    ])

    test_transforms = utils.data_transforms.Compose([
        # utils.data_transforms.BorderCrop(cfg.CONST.SCALE),
        utils.data_transforms.BGR2RGB(),
        # utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD),
        utils.data_transforms.ToTensor()
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.DATASET_TRAIN_NAME](utils.data_loaders.DatasetType.TRAIN)
    test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.DATASET_TEST_NAME](utils.data_loaders.DatasetType.TEST)
    if cfg.NETWORK.PHASE in ['train', 'resume']:
        train_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset_loader.get_dataset(train_transforms),
            batch_size=cfg.CONST.TRAIN_BATCH_SIZE,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=True)
        val_data_loader = torch.utils.data.DataLoader(
            dataset=test_dataset_loader.get_dataset(test_transforms),
            batch_size=cfg.CONST.VAL_BATCH_SIZE,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)
    elif cfg.NETWORK.PHASE in ['test']:
        test_data_loader = torch.utils.data.DataLoader(
            dataset=test_dataset_loader.get_dataset(test_transforms),
            batch_size=cfg.CONST.TEST_BATCH_SIZE,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    net = models.__dict__[cfg.NETWORK.SRNETARCH].__dict__[
        cfg.NETWORK.SRNETARCH]()
    print('[DEBUG] %s Parameters in %s: %d.' %
          (dt.now(), cfg.NETWORK.SRNETARCH, net_utils.count_parameters(net)))

    # Initialize weights of networks
    if cfg.NETWORK.PHASE == 'train':
        net_utils.initialize_weights(net, cfg.TRAIN.KAIMING_SCALE)

    # Set up solver
    solver = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                     net.parameters()),
                              lr=cfg.TRAIN.LEARNING_RATE,
                              betas=(cfg.TRAIN.MOMENTUM, cfg.TRAIN.BETA))

    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net, range(cfg.CONST.NUM_GPU)).cuda()

    # Load pretrained model if exists
    Init_Epoch = 0
    Best_Epoch = 0
    Best_PSNR = 0
    if cfg.NETWORK.PHASE in ['test', 'resume']:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)

        net.load_state_dict(checkpoint['net_state_dict'])

        if cfg.NETWORK.PHASE == 'resume':
            Init_Epoch = checkpoint['epoch_idx']
        Best_PSNR = checkpoint['best_PSNR']
        Best_Epoch = checkpoint['best_epoch']
        if 'solver_state_dict' in checkpoint:
            solver.load_state_dict(checkpoint['solver_state_dict'])

        print('[INFO] {0} Recover complete. Current Epoch #{1}, Best_PSNR = {2} at Epoch #{3}.' \
              .format(dt.now(), Init_Epoch, Best_PSNR, Best_Epoch))

    if cfg.NETWORK.PHASE in ['train', 'resume']:
        # Set up learning rate scheduler to decay learning rates dynamically
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            solver,
            milestones=cfg.TRAIN.LR_MILESTONES,
            gamma=cfg.TRAIN.LR_DECAY)
        # Summary writer for TensorBoard
        output_dir = os.path.join(
            cfg.DIR.OUT_PATH, 'tb_log',
            dt.now().isoformat() + '_' + cfg.NETWORK.SRNETARCH, '%s')
        log_dir = output_dir % 'logs'
        ckpt_dir = output_dir % 'checkpoints'
        train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(log_dir, 'val'))

        # train and val
        train(cfg, Init_Epoch, train_data_loader, val_data_loader, net, solver,
              lr_scheduler, ckpt_dir, train_writer, val_writer, Best_PSNR,
              Best_Epoch)
        return
    elif cfg.NETWORK.PHASE in ['test']:
        if cfg.DATASET.DATASET_TEST_NAME == 'Demo':
            test_woGT(cfg, test_data_loader, net)
        else:
            test(cfg, test_data_loader, net, Best_Epoch)
        return
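The debug print above relies on net_utils.count_parameters(net), which is not included in the snippet. A one-line sketch, assuming it counts trainable parameters only:

def count_parameters(net):
    """Return the number of trainable parameters in a model."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)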
Example 8
    def run(self):
        core.test(self.urlSync, self.spliderThreadPool, self.mutex)
Example 9
def bulid_net(cfg):

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.ColorJitter(cfg.DATA.COLOR_JITTER),
        utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD),
        utils.data_transforms.RandomCrop(cfg.DATA.CROP_IMG_SIZE),
        utils.data_transforms.RandomVerticalFlip(),
        utils.data_transforms.RandomHorizontalFlip(),
        utils.data_transforms.RandomColorChannel(),
        utils.data_transforms.RandomGaussianNoise(cfg.DATA.GAUSSIAN),
        utils.data_transforms.ToTensor(),
    ])

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.DATASET_NAME]()

    # Set up networks
    deblurnet = models.__dict__[cfg.NETWORK.DEBLURNETARCH].__dict__[cfg.NETWORK.DEBLURNETARCH]()

    print('[DEBUG] %s Parameters in %s: %d.' % (dt.now(), cfg.NETWORK.DEBLURNETARCH,
                                                utils.network_utils.count_parameters(deblurnet)))

    # Initialize weights of networks
    deblurnet.apply(utils.network_utils.init_weights_xavier)

    # Set up solver
    deblurnet_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, deblurnet.parameters()),
                                        lr=cfg.TRAIN.LEARNING_RATE,
                                        betas=(cfg.TRAIN.MOMENTUM, cfg.TRAIN.BETA))

    if torch.cuda.is_available():
        deblurnet = torch.nn.DataParallel(deblurnet).cuda()

    # Load pretrained model if exists
    init_epoch       = 0
    Best_Epoch       = -1
    Best_Img_PSNR    = 0


    if cfg.NETWORK.PHASE in ['test','resume']:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        deblurnet.load_state_dict(checkpoint['deblurnet_state_dict'])
        # deblurnet_solver.load_state_dict(checkpoint['deblurnet_solver_state_dict'])
        init_epoch = checkpoint['epoch_idx']+1
        Best_Img_PSNR = checkpoint['Best_Img_PSNR']
        Best_Epoch = checkpoint['Best_Epoch']
        print('[INFO] {0} Recover complete. Current epoch #{1}, Best_Img_PSNR = {2} at epoch #{3}.' \
              .format(dt.now(), init_epoch, Best_Img_PSNR, Best_Epoch))



    # Set up learning rate scheduler to decay learning rates dynamically
    deblurnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(deblurnet_solver,
                                                                   milestones=cfg.TRAIN.LR_MILESTONES,
                                                                   gamma=cfg.TRAIN.LR_DECAY)

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, dt.now().isoformat() + '_' + cfg.NETWORK.DEBLURNETARCH, '%s')
    log_dir      = output_dir % 'logs'
    ckpt_dir     = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    test_writer  = SummaryWriter(os.path.join(log_dir, 'test'))
    print('[INFO] Output_dir: {0}'.format(output_dir[:-2]))

    if cfg.NETWORK.PHASE in ['train','resume']:
        train(cfg, init_epoch, dataset_loader, train_transforms, test_transforms,
                              deblurnet, deblurnet_solver, deblurnet_lr_scheduler,
                              ckpt_dir, train_writer, test_writer,
                              Best_Img_PSNR, Best_Epoch)
    else:
        if os.path.exists(cfg.CONST.WEIGHTS):
            test(cfg, init_epoch, dataset_loader, test_transforms, deblurnet, test_writer)
        else:
            print('[FATAL] %s Please specify the file path of checkpoint.' % (dt.now()))
            sys.exit(2)
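The snippet above initializes the network with deblurnet.apply(utils.network_utils.init_weights_xavier), which is defined elsewhere. A minimal sketch of a typical Xavier initializer usable with Module.apply (the covered layer types and the normal-vs-uniform choice are assumptions):

import torch.nn as nn

def init_weights_xavier(m):
    """Xavier-initialize conv and linear layers; intended to be passed to model.apply(...)."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)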
Example 10
def train(classifier, generator, critic, src_data_loader, tgt_data_loader):
    """Train generator, classifier and critic jointly."""
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    classifier.train()
    generator.train()
    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_c = get_optimizer(classifier, "Adam")

    # zip source and target data pair
    data_iter_src = get_inf_iterator(src_data_loader)

    # counter
    g_step = 0

    ####################
    # 2. train network #
    ####################

    for epoch in range(params.num_epochs):
        ###########################
        # 2.1 train discriminator #
        ###########################
        # requires to compute gradients for D
        for p in critic.parameters():
            p.requires_grad = True

        # set steps for discriminator
        if g_step < 25 or g_step % 500 == 0:
            # this helps to start with the critic at optimum
            # even in the first iterations.
            critic_iters = 100
        else:
            critic_iters = params.d_steps
        critic_iters = 0
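        # NOTE: critic_iters is forced to 0 and the discriminator loop below is
        # commented out, so the critic is never actually updated in this variant.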
        # loop for optimizing discriminator
        #for d_step in range(critic_iters):
        # convert images into torch.Variable
        images_src, labels_src = next(data_iter_src)

        images_src = make_variable(images_src).cuda()
        labels_src = make_variable(labels_src.squeeze_()).cuda()
        # print(type(images_src))

        ########################
        # 2.2 train classifier #
        ########################

        # zero gradients for optimizer
        optimizer_c.zero_grad()

        # compute loss for critic
        preds_c = classifier(generator(images_src))
        c_loss = criterion(preds_c, labels_src)

        # optimize source classifier
        c_loss.backward()
        optimizer_c.step()
        g_step += 1

        ##################
        # 2.4 print info #
        ##################
        if ((epoch + 1) % 500 == 0):
            # print("Epoch [{}/{}]:"
            #       "c_loss={:.5f}"
            #       "D(x)={:.5f}"
            #       .format(epoch + 1,
            #               params.num_epochs,
            #               c_loss.item(),
            #               ))
            test(classifier, generator, src_data_loader, params.src_dataset)
        if ((epoch + 1) % 500 == 0):
            save_model(generator, "Mnist-generator-{}.pt".format(epoch + 1))
            save_model(classifier, "Mnist-classifer{}.pt".format(epoch + 1))
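This snippet and the next one both draw batches with next(...) from get_inf_iterator(data_loader), which is defined elsewhere. A minimal sketch of such an endless iterator over a DataLoader (an assumption, not necessarily the project's exact helper):

def get_inf_iterator(data_loader):
    """Yield batches from a DataLoader indefinitely, restarting it whenever it is exhausted."""
    while True:
        for batch in data_loader:
            yield batch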
Example 11
def train(classifier, generator, critic, src_data_loader, tgt_data_loader):
    """Train generator, classifier and critic jointly."""
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    classifier.train()
    generator.train()
    critic.train()

    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_c = get_optimizer(classifier, "Adam")
    optimizer_g = get_optimizer(generator, "Adam")
    optimizer_d = get_optimizer(critic, "Adam")

    # zip source and target data pair
    data_iter_src = get_inf_iterator(src_data_loader)
    data_iter_tgt = get_inf_iterator(tgt_data_loader)

    # counter
    g_step = 0

    # positive and negative labels
    pos_labels = make_variable(torch.FloatTensor([1]))
    neg_labels = make_variable(torch.FloatTensor([-1]))

    ####################
    # 2. train network #
    ####################

    for epoch in range(params.num_epochs):
        ###########################
        # 2.1 train discriminator #
        ###########################
        # requires to compute gradients for D
        for p in critic.parameters():
            p.requires_grad = True
        critic_iters = 5
        for d_step in range(critic_iters):
            # convert images into torch.Variable
            images_src, labels_src = next(data_iter_src)
            images_tgt, _ = next(data_iter_tgt)
            images_src = make_variable(images_src).cuda()
            labels_src = make_variable(labels_src.squeeze_()).cuda()
            images_tgt = make_variable(images_tgt).cuda()
            if images_src.size(0) != params.batch_size or \
                    images_tgt.size(0) != params.batch_size:
                continue

            # zero gradients for optimizer
            optimizer_d.zero_grad()

            # compute source data loss for discriminator
            feat_src = generator(images_src)
            d_loss_src = critic(feat_src.detach())
            d_loss_src = d_loss_src.mean()
            d_loss_src.backward(neg_labels)

            # compute target data loss for discriminator
            feat_tgt = generator(images_tgt)
            d_loss_tgt = critic(feat_tgt.detach())
            d_loss_tgt = d_loss_tgt.mean()
            d_loss_tgt.backward(pos_labels)

            # compute gradient penalty
            gradient_penalty = calc_gradient_penalty(critic, feat_src.data,
                                                     feat_tgt.data)
            gradient_penalty.backward()

            # optimize weights of discriminator
            d_loss = -d_loss_src + d_loss_tgt + gradient_penalty
            optimizer_d.step()

        ########################
        # 2.2 train classifier #
        ########################

        # zero gradients for optimizer
        optimizer_c.zero_grad()

        # compute loss for critic
        preds_c = classifier(generator(images_src).detach())
        c_loss = criterion(preds_c, labels_src)

        # optimize source classifier
        c_loss.backward()
        optimizer_c.step()

        #######################
        # 2.3 train generator #
        #######################
        # avoid computing gradients for D
        # zero grad for optimizer of generator
        optimizer_g.zero_grad()

        # compute source data classification loss for generator
        feat_src = generator(images_src)
        preds_c = classifier(feat_src)
        g_loss_cls = criterion(preds_c, labels_src)
        g_loss_cls.backward()

        # compute source data discrimination loss for generator
        feat_src = generator(images_src)
        g_loss_src = critic(feat_src).mean()
        g_loss_src.backward(pos_labels)

        # compute target data discrimination loss for generator
        feat_tgt = generator(images_tgt)
        g_loss_tgt = critic(feat_tgt).mean()
        g_loss_tgt.backward(neg_labels)

        # compute loss for generator
        g_loss = g_loss_src - g_loss_tgt + g_loss_cls

        # optimize weights of generator
        optimizer_g.step()
        g_step += 1

        ##################
        # 2.4 print info #
        ##################
        if ((epoch + 1) % params.log_step == 0):
            print("Epoch [{}/{}]:"
                  "d_loss={:.5f} c_loss={:.5f} g_loss={:.5f} "
                  "D(x)={:.5f} D(G(z))={:.5f} GP={:.5f}".format(
                      epoch + 1, params.num_epochs, d_loss.item(),
                      c_loss.item(), g_loss.item(), d_loss_src.item(),
                      d_loss_tgt.item(), gradient_penalty.item()))
            #test(classifier, generator, src_data_loader, params.src_dataset)
            print(">>> on target domain <<<")
        if ((epoch + 1) % 10 == 0):
            test(classifier, generator, tgt_data_loader, params.tgt_dataset)
        #############################
        # 2.5 save model parameters #
        #############################
        if ((epoch + 1) % params.save_step == 0):
            save_model(critic, "WGAN-GP_critic-{}.pt".format(epoch + 1))
            save_model(classifier,
                       "WGAN-GP_classifier-{}.pt".format(epoch + 1))
            save_model(generator, "WGAN-GP_generator-{}.pt".format(epoch + 1))

    return classifier, generator
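The critic update above calls calc_gradient_penalty(critic, feat_src.data, feat_tgt.data), which is not shown. A sketch following the standard WGAN-GP recipe (interpolate between source and target features and push the critic's gradient norm toward 1); the penalty weight and other details are assumptions:

import torch
from torch import autograd

def calc_gradient_penalty(critic, real_data, fake_data, lambda_gp=10.0):
    """WGAN-GP penalty: lambda * mean((||grad critic(x_hat)||_2 - 1) ** 2) on random interpolates."""
    batch_size = real_data.size(0)
    # one interpolation coefficient per sample, broadcast over the remaining dimensions
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)), device=real_data.device)
    interpolates = (alpha * real_data + (1.0 - alpha) * fake_data).requires_grad_(True)

    critic_out = critic(interpolates)
    gradients = autograd.grad(outputs=critic_out, inputs=interpolates,
                              grad_outputs=torch.ones_like(critic_out),
                              create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    return lambda_gp * ((gradients.norm(2, dim=1) - 1.0) ** 2).mean()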
Example 12
            raise Exception('Invalid --case option')

        muzero_config.set_game(args.env)
        muzero_config.set_exp_path(exp_path)
        muzero_config.set_device(args.device)

        if args.opr == 'train':
            train(muzero_config, summary_writer)

        elif args.opr == 'test':
            assert os.path.exists(
                muzero_config.model_path), 'model not found at {}'.format(
                    muzero_config.model_path)
            model = muzero_config.get_uniform_network().to('cpu')
            model.load_state_dict(
                torch.load(muzero_config.model_path,
                           map_location=torch.device('cpu')))
            test_score = test(muzero_config,
                              model,
                              args.test_episodes,
                              device='cpu',
                              render=args.render)
            logger.info('Test Score: {}'.format(test_score))
        else:
            raise Exception(
                'Please select a valid operation (--opr) to be performed')

        ray.shutdown()
    except Exception as e:
        logger.error(e, exc_info=True)