def main():
    """Build ColorNet, optionally resume from a checkpoint, and train it.

    Parses command-line options with the module-level ``parser`` and stores
    them in the global ``args``.

    Resuming: if ``args.resume`` names an existing file, the following are
    restored from that checkpoint:
      * model weights
      * optimizer state
      * best validation loss
      * epoch counter
    If the path does not exist, a message is printed and training starts
    from scratch.

    Training: if ``args.action == 'train'``, the model is trained for epochs
    ``start_epoch .. args.epoch - 1``. After every epoch it:
      * validates on ``args.val``;
      * updates the global ``best_loss``;
      * writes a checkpoint through ``save_checkpoint``, flagged ``is_best``
        when the validation loss improved.
    """
    global args, best_loss
    args = parser.parse_args()

    model = Model.ColorNet()
    model.cuda()
    criterion = nn.MSELoss().cuda()
    # NOTE(review): 0.1 is an unusually high initial learning rate for Adam;
    # presumably adjust_learning_rate() rescales it per epoch — confirm.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # First epoch to run; overwritten when resuming from a checkpoint.
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # torch.load unpickles the file: only resume from trusted paths.
            checkpoint = torch.load(args.resume)
            # The checkpoints written below store 'best_loss' (there is no
            # 'best_prec1' key), and the value must go into the global so the
            # is_best comparison continues from the restored best.
            best_loss = checkpoint['best_loss']
            # 'epoch' is saved as (last completed epoch + 1), i.e. the next
            # epoch to run, so training picks up where it stopped.
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.action == 'train':
        train_loader = dGenerator.makeDataLoader(args.data, 'train')
        val_loader = dGenerator.makeDataLoader(args.val, 'validate')
        for epoch in range(start_epoch, args.epoch):
            print('=' * 10 + 'epoch ' + str(epoch) + '=' * 10)
            adjust_learning_rate(optimizer, epoch)
            Run.train(train_loader, model, criterion, optimizer, epoch)
            loss = Run.validate(val_loader, model, criterion)
            print('loss: ' + str(loss) + '\n')

            # Lower validation loss is better.
            is_best = loss < best_loss
            best_loss = min(loss, best_loss)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': 'inception_v3',
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best)
from Dataset.CocoDataset import * from Dataset.BoxLoader import * from Utils.RunManager import * from Utils.CheckpointLoader import * from BoxInceptionResnet import * from Dataset import Augment from Visualize import VisualizeOutput from Utils import Model from Utils import Export from tensorflow.python.client import timeline import re globalStep = tf.Variable(0, name='globalStep', trainable=False) globalStepInc = tf.assign_add(globalStep, 1) Model.download() dataset = BoxLoader() dataset.add(CocoDataset(opt.dataset, randomZoom=opt.randZoom == 1, set="train" + opt.cocoVariant)) if opt.mergeValidationSet == 1: dataset.add(CocoDataset(opt.dataset, set="val" + opt.cocoVariant)) images, boxes, classes = Augment.augment(*dataset.get()) print(f"Number of categories: {str(dataset.categoryCount())}") print(dataset.getCaptionMap()) net = BoxInceptionResnet(images, dataset.categoryCount(), name="boxnet", trainFrom=opt.trainFrom,