Example 1
    def build(self):
        # Only CNN-type models are handled here.
        if self.config['model']['type'] != 'CNN':
            raise ValueError("Model type is not valid: expected 'CNN', got '{0}'.".format(
                self.config['model']['type']))

        # The model name encodes both halves, e.g. 'vgg_unet' -> ('VGG', 'UNET').
        encoder, decoder = self.config['model']['name'].upper().split('_')

        self.encoder = make_encoder(e_config[encoder])
        self.decoder = make_decoder(d_config[decoder])
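This build() relies on module-level pieces (e_config, d_config, make_encoder, make_decoder) that the example does not show. A minimal sketch of what they might look like, with all shapes and values hypothetical and inferred only from how they are used above:

import torch.nn as nn

# Hypothetical registries keyed by the upper-cased halves of
# config['model']['name'], e.g. 'vgg_unet' -> ('VGG', 'UNET').
e_config = {'VGG': {'out_channels': 512}}
d_config = {'UNET': {'in_channels': 512}}

def make_encoder(cfg):
    # Placeholder factory; a real implementation would assemble the CNN here.
    return nn.Conv2d(3, cfg['out_channels'], kernel_size=3, padding=1)

def make_decoder(cfg):
    return nn.Conv2d(cfg['in_channels'], 1, kernel_size=1)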
Example 2
    def build(self):
        # Only RNN-type models are handled here.
        if self.config['model']['type'] != 'RNN':
            raise ValueError("Model type is not valid: expected 'RNN', got '{0}'.".format(
                self.config['model']['type']))

        self.e_name, self.d_name = self.config['model']['name'].upper().split('_')

        # Build both halves and move them to the GPU.
        self.encoder = make_encoder(e_config[self.e_name], self.config).cuda()
        self.decoder = make_decoder(d_config[self.d_name]).cuda()
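For context, a config dict that would satisfy this build() (keys inferred from the accesses above; the name value is hypothetical):

config = {'model': {'type': 'RNN', 'name': 'vgg_gru'}}
# 'vgg_gru'.upper().split('_') -> ('VGG', 'GRU')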
Example 3
def main():
    global args, best_prec1, model, train_loader, val_loaders, output, target
    args = parser.parse_args()

    logging.basicConfig(
        format="%(message)s",
        handlers=[
            # logging.FileHandler("{0}/{1}.log".format(args.log, sys.argv[0].replace('.py','') + datetime.now().strftime('_%H_%M_%d_%m_%Y'))),
            logging.FileHandler("{0}/{1}.log".format(
                args.log, sys.argv[0].replace('.py', '') + args.name)),
            logging.StreamHandler()
        ],
        level=logging.INFO)

    for sigma in range(6, 16, 2):  # sweep sigma over 6, 8, 10, 12, 14
        # create model
        if args.pretrained:
            pass
            # logging.info("=> using pre-trained model '{}'".format(args.arch))
            # model = models.__dict__[args.arch](pretrained=True)
        else:
            model_name = CONFIG['model']['name']
            en_name = model_name.split('_')[0]
            logging.info("=> creating model '{}'".format(en_name))
            model = make_encoder(en_name)
            model._initialize_weights()
            # model.load_imagenet_weights()
            # CONFIG['train']['user'] = [8]
            # model = Encoder(CONFIG)
            # model._initialize_weights()
            # model.encoder.load_weights()

        model.cuda()

        # define loss function (criterion) and optimizer
        #criterion = nn.BCEWithLogitsLoss().cuda()
        #criterion = nn.BCELoss().cuda()
        criterion = nn.KLDivLoss()  # note: KLDivLoss expects log-probabilities as input

        # optimizer = torch.optim.SGD(
        # 				list(model.encoder[1].parameters()) +
        # 				list(model.readout.parameters()),
        # 				# model.parameters(),
        # 				args.lr,
        # 				momentum=args.momentum,
        # 				weight_decay=args.weight_decay)

        # for param in model.encoder[0].parameters():
        # 	param.requires_grad = False

        optimizer = torch.optim.Adam(
            # list(model.encoder[1].parameters()) +
            # list(model.readout.parameters()),
            model.parameters(),
            args.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        cudnn.benchmark = True

        if args.resume:
            if os.path.isfile(args.resume):
                logging.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logging.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        train_duration = (args.name == 'TDW')
        # eval_duration = True if (args.name.split('-')[1] == 'GTDW') else False

        print(args.name, train_duration)  #, eval_duration)

        train_dataset = Saliency()
        train_dataset.load('OSIE',
                           'train',
                           duration=train_duration,
                           split={
                               'train': 0.90,
                               'eval': 0.10
                           },
                           sigma=sigma)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)

        # val_loaders = list()

        # for name_idx, name in enumerate(CONFIG['eval']['dataset']):
        # 	if name == 'OSIE':
        # 		split={'train' :0.75 , 'eval' : 0.25}
        # 	else:
        # 		split={'train' :0.0 , 'eval' : 1.0}
        # 	val_ds = Saliency()
        # 	# val_ds.load(name, 'eval', split=split,
        # 	# 			 duration=eval_duration, sigma=sigma)
        # 	val_ds.load(name, 'eval', split=split,
        # 				 duration=eval_duration, sigma=sigma)

        # 	val_loaders.append(
        # 		torch.utils.data.DataLoader(
        # 			val_ds,
        # 			batch_size=args.batch_size * 2, shuffle=False,
        # 			num_workers=args.workers, pin_memory=True)
        # 	)

        if args.evaluate:
            # Note: val_loaders is only populated by the commented-out block
            # above; as written, this path would raise a NameError.
            validate(val_loaders, model, criterion)
            return

        # if args.visualize:
        # 	visualize(val_loader, model)
        # 	return

        for epoch in range(args.start_epoch, args.epochs):

            try:

                adjust_learning_rate(optimizer, epoch)

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch, sigma)

                # evaluate on validation set
                # prec1 = validate(val_loaders, model, criterion, epoch)

                # remember best prec@1 and save checkpoint
                #	is_best = prec1 > best_prec1
                # best_prec1 = max(prec1, best_prec1)
                #best_prec1 = 1
                if epoch in [4, 9, 14]:  # checkpoint after the 5th, 10th and 15th epochs
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'arch': args.arch,
                            'state_dict': model.state_dict(),
                            'best_prec1': True,
                            'optimizer': optimizer.state_dict(),
                        }, True, sigma)

            except Exception as exc:
                # Keep the sigma sweep going even if one epoch fails.
                print(exc)
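save_checkpoint is called above but not defined in the example. A minimal sketch, assuming it follows the canonical PyTorch ImageNet script with the sigma value added to the (hypothetical) filenames:

import shutil
import torch

def save_checkpoint(state, is_best, sigma, filename='checkpoint-{0}.pth.tar'):
    # Persist the full training state, tagged with the current sigma.
    path = filename.format(sigma)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, 'model_best-{0}.pth.tar'.format(sigma))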
Example 4
def main():
    global args, best_prec1, model, train_loader, val_loaders, output, target
    args = parser.parse_args()

    logging.basicConfig(format="%(message)s",
                        handlers=[
                            logging.FileHandler("{0}/{1}.log".format(
                                args.log, sys.argv[0].replace('.py', '') +
                                datetime.now().strftime('_%H_%M_%d_%m_%Y'))),
                            logging.StreamHandler()
                        ],
                        level=logging.INFO)

    # create model
    if args.pretrained:
        pass
        # logging.info("=> using pre-trained model '{}'".format(args.arch))
        # model = models.__dict__[args.arch](pretrained=True)
    else:
        model_name = CONFIG['model']['name']
        en_name = model_name.split('_')[0]
        logging.info("=> creating model '{}'".format(en_name))
        model = make_encoder(en_name)
        model._initialize_weights()
        # model.load_imagenet_weights()
        # CONFIG['train']['user'] = [8]
        # model = Encoder(CONFIG)
        # model._initialize_weights()
        # model.encoder.load_weights()

    model.cuda()

    # define loss function (criterion) and optimizer
    #criterion = nn.BCEWithLogitsLoss().cuda()
    criterion = nn.BCELoss().cuda()

    # optimizer = torch.optim.SGD(
    # 				list(model.encoder[1].parameters()) +
    # 				list(model.readout.parameters()),
    # 				# model.parameters(),
    # 				args.lr,
    # 				momentum=args.momentum,
    # 				weight_decay=args.weight_decay)

    # for param in model.encoder[0].parameters():
    # 	param.requires_grad = False

    optimizer = torch.optim.Adam(
        # list(model.encoder[1].parameters()) +
        # list(model.readout.parameters()),
        model.parameters(),
        args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    cudnn.benchmark = True

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

    # train_duration = True if (args.name.split('-')[0] == 'TDW') else False
    # eval_duration = True if (args.name.split('-')[1] == 'GTDW') else False

    # print(args.name, train_duration, eval_duration)

    # train_dataset = Saliency()
    # train_dataset.load('OSIE', 'train', duration=train_duration)
    # train_loader = torch.utils.data.DataLoader(
    # 	train_dataset, batch_size=args.batch_size, shuffle=False,
    # 	num_workers=args.workers, pin_memory=True, sampler=None)

    # val_loaders = list()

    # for name_idx, name in enumerate(CONFIG['eval']['dataset']):
    # 	# if name == 'OSIE':
    # 	# 	split={'train' :0.75 , 'eval' : 0.25}
    # 	# else:
    # 	split={'train' :0.0 , 'eval' : 1.0}
    # 	val_ds = Saliency()
    # 	val_ds.load(name, 'eval', split=split,
    # 				 duration=eval_duration)

    # 	val_loaders.append(
    # 		torch.utils.data.DataLoader(
    # 			val_ds,
    # 			batch_size=args.batch_size * 2, shuffle=False,
    # 			num_workers=args.workers, pin_memory=True)
    # 	)

    # if args.evaluate:
    # 	validate(val_loaders, model, criterion)
    # 	return

    # if args.visualize:
    # 	visualize(val_loader, model)
    # 	return

    # policies = [
    # 	[
    # 		'TS-GTS-5.pth.tar',
    # 		'TS-GTS-10.pth.tar',
    # 		'TS-GTS-15.pth.tar'],
    # 	[
    # 		'TS-GTDW-5.pth.tar',
    # 		'TS-GTDW-10.pth.tar',
    # 		'TS-GTDW-15.pth.tar'],
    # 	[
    # 		'TDW-GTS-5.pth.tar',
    # 		'TDW-GTS-10.pth.tar',
    # 		'TDW-GTS-15.pth.tar'],

    # 	[
    # 		'TDW-GTDW-5.pth.tar',
    # 		'TDW-GTDW-10.pth.tar',
    # 		'TDW-GTDW-15.pth.tar'],
    # ]
    policies = [
        'TS-{0}-{1}.pth.tar',
        'TDW-{0}-{1}.pth.tar',
    ]

    for train_policy_idx, train_policy in enumerate(policies):

        for eval_policy_idx, eval_policy in enumerate(['GTS', 'GTDW']):

            for sigma in range(6, 15, 2):

                val_loaders = list()
                # model_name = '-'.join(policy[0].split('-')[:2])
                # eval_duration = True if (model_name.split('-')[1] == 'GTDW') else False
                eval_duration = (eval_policy == 'GTDW')

                for name_idx, name in enumerate(CONFIG['eval']['dataset']):
                    # if name == 'OSIE':
                    # 	split={'train' :0.75 , 'eval' : 0.25}
                    # else:

                    split = {'train': 0.0, 'eval': 1.0}
                    val_ds = Saliency()
                    val_ds.load(name,
                                'eval',
                                split=split,
                                duration=eval_duration,
                                sigma=sigma)

                    val_loaders.append(
                        torch.utils.data.DataLoader(
                            val_ds,
                            batch_size=args.batch_size * 2,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True))

                for epoch in [15]:  # only the final checkpoint is evaluated here

                    weight = train_policy.format(epoch, sigma)
                    logging.info('evaluating: {0} - {1}'.format(
                        weight, eval_policy))

                    try:

                        weight = os.path.join(args.weights, weight)
                        # adjust_learning_rate(optimizer, epoch)
                        model = load_check_point(model, weight)

                        prec1 = validate(
                            val_loaders, model, criterion,
                            (train_policy_idx, eval_policy_idx, sigma, epoch))

                        # Dump the cumulative results after every evaluation.
                        with open('results/TOTAL.npy', 'wb') as f:
                            np.save(f, np.array(RESULTS))

                    except Exception as exc:
                        # Keep sweeping even if one checkpoint is missing or fails.
                        print(exc)
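load_check_point is likewise not shown. A plausible minimal version, assuming the checkpoints were written by a save_checkpoint like the one sketched after Example 3:

import torch

def load_check_point(model, weight_path):
    # Restore only the model weights; optimizer state is not needed for evaluation.
    checkpoint = torch.load(weight_path)
    model.load_state_dict(checkpoint['state_dict'])
    return model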
Example 5
def main():
    global args, best_prec1, model, train_dataset, val_loader, en_name
    args = parser.parse_args()

    logging.basicConfig(format="%(message)s",
                        handlers=[
                            logging.FileHandler("{0}/{1}.log".format(
                                args.log, sys.argv[0].replace('.py', '') +
                                datetime.now().strftime('_%H_%M_%d_%m_%Y'))),
                            logging.StreamHandler()
                        ],
                        level=logging.INFO)

    # create model
    if args.pretrained:
        pass
        # logging.info("=> using pre-trained model '{}'".format(args.arch))
        # model = models.__dict__[args.arch](pretrained=True)
    else:
        model_name = CONFIG['model']['name']
        en_name = model_name.split('_')[0]
        logging.info("=> creating model '{}'".format(en_name))
        model = make_encoder(e_config[en_name], CONFIG)
        # model._initialize_weights()

    model.cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.BCELoss().cuda()

    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    # 							momentum=args.momentum,
    # 							weight_decay=args.weight_decay)

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    cudnn.benchmark = True

    # Data loading code

    train_dataset = Saliency(CONFIG, 'train')

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_loader = torch.utils.data.DataLoader(Saliency(CONFIG, 'test'),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.visualize:
        visualize(val_loader, model)
        return

    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
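adjust_learning_rate is called in Examples 3 and 5 but never defined. A minimal sketch, assuming the standard step schedule from the PyTorch ImageNet reference script (the schedule actually used here is unknown):

def adjust_learning_rate(optimizer, epoch):
    # Decay the learning rate by a factor of 10 every 30 epochs.
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr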