def valid(epoch_id, valid_loader, fetch_list, test_prog, exe):
    """Evaluate ``test_prog`` on every batch produced by ``valid_loader``.

    Each per-device validation sample is fed under both the ``*_train`` and
    the ``*_val`` input names, so the program only ever sees validation
    data. Loss, top-1 and top-5 values are accumulated in
    ``utility.AvgrageMeter`` instances and logged every
    ``args.report_freq`` steps.

    Note: ``args``, ``utility`` and ``logger`` are not parameters of this
    function; they are looked up in the enclosing module.

    Args:
        epoch_id: Epoch number, used only in the log message.
        valid_loader: Callable returning an iterable of batches. Each batch
            holds one mapping per device with ``'image_val'`` and
            ``'label_val'`` entries.
        fetch_list: Passed unchanged to ``exe.run``; three results (loss,
            top-1, top-5) are unpacked from the call.
        test_prog: Program handed to ``exe.run``.
        exe: Object exposing ``run(program, feed=..., fetch_list=...)``.

    Returns:
        ``avg[0]`` of the top-1 meter after the last batch.
    """
    loss_meter = utility.AvgrageMeter()
    top1_meter = utility.AvgrageMeter()
    top5_meter = utility.AvgrageMeter()

    for step_id, valid_data in enumerate(valid_loader()):
        # Use the validation tensors for the train inputs as well.
        per_device_feed = [
            {
                "image_train": sample['image_val'],
                "label_train": sample['label_val'],
                "image_val": sample['image_val'],
                "label_val": sample['label_val'],
            }
            for sample in valid_data
        ]

        fetched = exe.run(
            test_prog, feed=per_device_feed, fetch_list=fetch_list)
        loss_v, top1_v, top5_v = fetched

        # Every update is weighted by the configured batch size.
        loss_meter.update(loss_v, args.batch_size)
        top1_meter.update(top1_v, args.batch_size)
        top5_meter.update(top5_v, args.batch_size)

        if step_id % args.report_freq != 0:
            continue
        message = (
            "Valid Epoch {}, Step {}, loss {:.3f}, acc_1 {:.6f}, acc_5 {:.6f}"
            .format(epoch_id, step_id, loss_meter.avg[0], top1_meter.avg[0],
                    top5_meter.avg[0]))
        logger.info(message)

    return top1_meter.avg[0]
def infer(main_prog, exe, valid_reader, fetch_list, args):
    """Run ``main_prog`` over ``valid_reader`` and report test metrics.

    Args:
        main_prog: Program handed to ``exe.run``.
        exe: Object exposing ``run(program, feed=..., fetch_list=...)``.
        valid_reader: Callable returning an iterable of ``(image, label)``
            pairs.
        fetch_list: Iterable of objects with a ``name`` attribute; the names
            are what is passed to ``exe.run``. Three results (loss, top-1,
            top-5) are unpacked from the call.
        args: Object providing ``batch_size`` (weight for each meter update)
            and ``report_freq`` (logging interval in steps).

    Returns:
        ``avg[0]`` of the top-1 meter after the last batch.
    """
    loss_meter = utility.AvgrageMeter()
    top1_meter = utility.AvgrageMeter()
    top5_meter = utility.AvgrageMeter()

    for step_id, batch in enumerate(valid_reader()):
        image, label = batch
        inputs = {"image": image, "label": label}
        fetch_names = [var.name for var in fetch_list]

        loss_v, top1_v, top5_v = exe.run(
            main_prog, feed=inputs, fetch_list=fetch_names)

        loss_meter.update(loss_v, args.batch_size)
        top1_meter.update(top1_v, args.batch_size)
        top5_meter.update(top5_v, args.batch_size)

        if step_id % args.report_freq != 0:
            continue
        message = "Test Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".format(
            step_id, loss_meter.avg[0], top1_meter.avg[0], top5_meter.avg[0])
        logger.info(message)

    return top1_meter.avg[0]
def valid(main_prog, exe, epoch_id, valid_loader, fetch_list, args):
    """Evaluate ``main_prog`` on every batch from ``valid_loader``.

    Args:
        main_prog: Program handed to ``exe.run``.
        exe: Object exposing ``run(program, feed=..., fetch_list=...)``.
        epoch_id: Epoch number, used only in the log message.
        valid_loader: Callable returning an iterable of batches; each batch
            is passed to ``exe.run`` as ``feed`` without modification.
        fetch_list: Iterable of objects with a ``name`` attribute; the names
            are what is passed to ``exe.run``. Three results (loss, top-1,
            top-5) are unpacked from the call.
        args: Object providing ``batch_size`` (weight for each meter update)
            and ``report_freq`` (logging interval in steps).

    Returns:
        Tuple ``(top1, top5)`` holding ``avg[0]`` of the respective meters
        after the last batch.
    """
    loss_meter = utility.AvgrageMeter()
    top1_meter = utility.AvgrageMeter()
    top5_meter = utility.AvgrageMeter()

    for step_id, batch in enumerate(valid_loader()):
        fetch_names = [var.name for var in fetch_list]
        loss_v, top1_v, top5_v = exe.run(
            main_prog, feed=batch, fetch_list=fetch_names)

        loss_meter.update(loss_v, args.batch_size)
        top1_meter.update(top1_v, args.batch_size)
        top5_meter.update(top5_v, args.batch_size)

        if step_id % args.report_freq != 0:
            continue
        message = (
            "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}"
            .format(epoch_id, step_id, loss_meter.avg[0], top1_meter.avg[0],
                    top5_meter.avg[0]))
        logger.info(message)

    return top1_meter.avg[0], top5_meter.avg[0]
def train(main_prog, exe, epoch_id, train_loader, fetch_list, args):
    """Run one pass of ``main_prog`` over ``train_loader``.

    When ``args.drop_path_prob`` is positive, every per-device sample is
    extended with two extra inputs:

    * ``drop_path_prob``: a ``(batch_size, 1)`` float32 array filled with
      ``args.drop_path_prob * epoch_id / args.epochs``;
    * ``drop_path_mask``: a float32 array of shape
      ``[batch_size, layers, 4, 2]`` equal to one minus a Bernoulli draw
      with that probability (freshly sampled for each device).

    Otherwise the loader's batch is fed unchanged.

    Args:
        main_prog: Program handed to ``exe.run``.
        exe: Object exposing ``run(program, feed=..., fetch_list=...)``.
        epoch_id: Current epoch; scales the drop-path probability and
            appears in the log message.
        train_loader: Callable returning an iterable of batches. Each batch
            supports ``len`` and, on the drop-path branch, holds one mapping
            per device with ``'image'`` and ``'label'`` entries.
        fetch_list: Iterable of objects with a ``name`` attribute; the names
            are what is passed to ``exe.run``. Four results (loss, top-1,
            top-5, learning rate) are unpacked from the call.
        args: Object providing ``drop_path_prob``, ``epochs``,
            ``batch_size``, ``layers`` and ``report_freq``.

    Returns:
        ``avg[0]`` of the top-1 meter after the last batch.
    """
    loss_meter = utility.AvgrageMeter()
    top1_meter = utility.AvgrageMeter()
    top5_meter = utility.AvgrageMeter()

    for step_id, data in enumerate(train_loader()):
        devices_num = len(data)

        if not args.drop_path_prob > 0:
            feed = data
        else:
            feed = []
            num_cells = 4
            for device_id in range(devices_num):
                sample = data[device_id]
                # One probability per example, stored as a column vector.
                drop_path_prob = np.full(
                    (args.batch_size, 1),
                    args.drop_path_prob * epoch_id / args.epochs,
                    dtype=np.float32)
                mask_shape = [args.batch_size, args.layers, num_cells, 2]
                sampled = np.random.binomial(
                    1, drop_path_prob[0], size=mask_shape)
                # The binomial draws a 1 with the drop probability; the mask
                # fed to the program is the complement.
                drop_path_mask = 1 - sampled.astype(np.float32)
                feed.append({
                    "image": sample['image'],
                    "label": sample['label'],
                    "drop_path_prob": drop_path_prob,
                    "drop_path_mask": drop_path_mask,
                })

        fetch_names = [var.name for var in fetch_list]
        loss_v, top1_v, top5_v, lr = exe.run(
            main_prog, feed=feed, fetch_list=fetch_names)

        loss_meter.update(loss_v, args.batch_size)
        top1_meter.update(top1_v, args.batch_size)
        top5_meter.update(top5_v, args.batch_size)

        if step_id % args.report_freq != 0:
            continue
        message = (
            "Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}"
            .format(epoch_id, step_id, lr[0], loss_meter.avg[0],
                    top1_meter.avg[0], top5_meter.avg[0]))
        logger.info(message)

    return top1_meter.avg[0]
def train(epoch_id, train_loader, valid_loader, fetch_list, arch_progs_list,
          train_prog, exe):
    """Run one epoch that alternates architecture and weight programs.

    Training and validation batches are consumed in lockstep. For each pair,
    the per-device mappings are merged (validation entries override training
    entries on key collision) and the merged feed is run through
    ``arch_progs_list[0]``, ``arch_progs_list[1]`` and finally
    ``train_prog``, whose fetched values update the meters.

    Note: ``args``, ``utility`` and ``logger`` are not parameters of this
    function; they are looked up in the enclosing module.

    Args:
        epoch_id: Epoch number, used only in the log message.
        train_loader: Callable returning an iterable of training batches,
            each holding one mapping per device.
        valid_loader: Callable returning an iterable of validation batches,
            indexed per device like the training batches.
        fetch_list: Iterable of objects with a ``name`` attribute; the names
            are what is passed to ``exe.run`` for ``train_prog``. Four
            results (learning rate, loss, top-1, top-5) are unpacked.
        arch_progs_list: Sequence whose first two items are programs run,
            in order, before ``train_prog`` on the same feed.
        train_prog: Program whose outputs are fetched and averaged.
        exe: Object exposing ``run(program, feed=..., fetch_list=...)``.

    Returns:
        ``avg[0]`` of the top-1 meter after the last batch.
    """
    loss_meter = utility.AvgrageMeter()
    top1_meter = utility.AvgrageMeter()
    top5_meter = utility.AvgrageMeter()

    paired_batches = zip(train_loader(), valid_loader())
    for step_id, (train_data, valid_data) in enumerate(paired_batches):
        # Merge the train and validation inputs of each device into one feed.
        feed = [
            dict(train_part, **valid_data[device_id])
            for device_id, train_part in enumerate(train_data)
        ]

        exe.run(arch_progs_list[0], feed=feed)
        exe.run(arch_progs_list[1], feed=feed)

        fetch_names = [var.name for var in fetch_list]
        lr, loss_v, top1_v, top5_v = exe.run(
            train_prog, feed=feed, fetch_list=fetch_names)

        loss_meter.update(loss_v, args.batch_size)
        top1_meter.update(top1_v, args.batch_size)
        top5_meter.update(top5_v, args.batch_size)

        if step_id % args.report_freq != 0:
            continue
        message = (
            "Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}"
            .format(epoch_id, step_id, lr[0], loss_meter.avg[0],
                    top1_meter.avg[0], top5_meter.avg[0]))
        logger.info(message)

    return top1_meter.avg[0]