diff --git a/src/opnsense/scripts/netflow/lib/aggregate.py b/src/opnsense/scripts/netflow/lib/aggregate.py index 11d675f54..5b0ebdcc9 100644 --- a/src/opnsense/scripts/netflow/lib/aggregate.py +++ b/src/opnsense/scripts/netflow/lib/aggregate.py @@ -355,60 +355,61 @@ class BaseFlowAggregator(object): :return: iterator returning dict records (start_time, end_time, [fields], octets, packets) """ result = list() - select_fields = self._valid_fields(fields) - filter_fields = [] - query_params = {} - if value_field == 'octets': - value_sql = 'sum(octets)' - elif value_field == 'packets': - value_sql = 'sum(packets)' - else: - value_sql = '0' + if self.is_db_open() and 'timeserie' in self._known_targets: + select_fields = self._valid_fields(fields) + filter_fields = [] + query_params = {} + if value_field == 'octets': + value_sql = 'sum(octets)' + elif value_field == 'packets': + value_sql = 'sum(packets)' + else: + value_sql = '0' - # query filters, correct start_time for resolution - query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution) - query_params['end_time'] = self._parse_timestamp(end_time) - if data_filters: - for data_filter in data_filters.split(','): - tmp = data_filter.split('=')[0].strip() - if tmp in self.agg_fields and data_filter.find('=') > -1: - filter_fields.append(tmp) - query_params[tmp] = '='.join(data_filter.split('=')[1:]) + # query filters, correct start_time for resolution + query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution) + query_params['end_time'] = self._parse_timestamp(end_time) + if data_filters: + for data_filter in data_filters.split(','): + tmp = data_filter.split('=')[0].strip() + if tmp in self.agg_fields and data_filter.find('=') > -1: + filter_fields.append(tmp) + query_params[tmp] = '='.join(data_filter.split('=')[1:]) - if len(select_fields) > 0: - # construct sql query to filter and select data - sql_select = 'select %s' % ','.join(select_fields) - sql_select += ', %s as total, max(last_seen) last_seen \n' % value_sql - sql_select += 'from timeserie \n' - sql_select += 'where mtime >= :start_time and mtime < :end_time\n' - for filter_field in filter_fields: - sql_select += ' and %s = :%s \n' % (filter_field, filter_field) - sql_select += 'group by %s\n'% ','.join(select_fields) - sql_select += 'order by %s desc ' % value_sql + if len(select_fields) > 0: + # construct sql query to filter and select data + sql_select = 'select %s' % ','.join(select_fields) + sql_select += ', %s as total, max(last_seen) last_seen \n' % value_sql + sql_select += 'from timeserie \n' + sql_select += 'where mtime >= :start_time and mtime < :end_time\n' + for filter_field in filter_fields: + sql_select += ' and %s = :%s \n' % (filter_field, filter_field) + sql_select += 'group by %s\n'% ','.join(select_fields) + sql_select += 'order by %s desc ' % value_sql - # execute select query - cur = self._db_connection.cursor() - cur.execute(sql_select, query_params) + # execute select query + cur = self._db_connection.cursor() + cur.execute(sql_select, query_params) - # fetch all data, to a max of [max_hits] rows. - field_names = (map(lambda x:x[0], cur.description)) - for record in cur.fetchall(): - result_record = dict() - for field_indx in range(len(field_names)): - if len(record) > field_indx: - result_record[field_names[field_indx]] = record[field_indx] - if len(result) < max_hits: - result.append(result_record) - else: - if len(result) == max_hits: - # generate row for "rest of data" - result.append({'total': 0}) - for key in result_record: - if key not in result[-1]: - result[-1][key] = "" - result[-1]['total'] += result_record['total'] - # close cursor - cur.close() + # fetch all data, to a max of [max_hits] rows. + field_names = (map(lambda x:x[0], cur.description)) + for record in cur.fetchall(): + result_record = dict() + for field_indx in range(len(field_names)): + if len(record) > field_indx: + result_record[field_names[field_indx]] = record[field_indx] + if len(result) < max_hits: + result.append(result_record) + else: + if len(result) == max_hits: + # generate row for "rest of data" + result.append({'total': 0}) + for key in result_record: + if key not in result[-1]: + result[-1][key] = "" + result[-1]['total'] += result_record['total'] + # close cursor + cur.close() return result @@ -418,25 +419,26 @@ class BaseFlowAggregator(object): :param end_time: end timestamp :return: iterator """ - query_params = {} - query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution) - query_params['end_time'] = self._parse_timestamp(end_time) - sql_select = 'select mtime start_time, ' - sql_select += '%s, octets, packets, last_seen as "last_seen [timestamp]" \n' % ','.join(self.agg_fields) - sql_select += 'from timeserie \n' - sql_select += 'where mtime >= :start_time and mtime < :end_time\n' - cur = self._db_connection.cursor() - cur.execute(sql_select, query_params) + if self.is_db_open() and 'timeserie' in self._known_targets: + query_params = {} + query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution) + query_params['end_time'] = self._parse_timestamp(end_time) + sql_select = 'select mtime start_time, ' + sql_select += '%s, octets, packets, last_seen as "last_seen [timestamp]" \n' % ','.join(self.agg_fields) + sql_select += 'from timeserie \n' + sql_select += 'where mtime >= :start_time and mtime < :end_time\n' + cur = self._db_connection.cursor() + cur.execute(sql_select, query_params) - # fetch all data, to a max of [max_hits] rows. - field_names = (map(lambda x:x[0], cur.description)) - while True: - record = cur.fetchone() - if record is None: - break - else: - result_record=dict() - for field_indx in range(len(field_names)): - if len(record) > field_indx: - result_record[field_names[field_indx]] = record[field_indx] - yield result_record + # fetch all data, to a max of [max_hits] rows. + field_names = (map(lambda x:x[0], cur.description)) + while True: + record = cur.fetchone() + if record is None: + break + else: + result_record=dict() + for field_indx in range(len(field_names)): + if len(record) > field_indx: + result_record[field_names[field_indx]] = record[field_indx] + yield result_record